mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Refactors and cleanups.
This commit is contained in:
parent
77d42ed7e9
commit
c7b2c3b569
@ -781,6 +781,7 @@ class ACEAudio(LatentFormat):
|
||||
|
||||
class SeedVR2(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
|
||||
@ -22,33 +22,14 @@ def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||
return out.reshape(-1, heads * head_dim)
|
||||
|
||||
|
||||
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||
raise ValueError(f"{name} must use an integer dtype")
|
||||
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||
if cu_seqlens[0].item() != 0:
|
||||
raise ValueError(f"{name} must start at 0")
|
||||
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||
raise ValueError(f"{name} must be strictly increasing")
|
||||
if cu_seqlens[-1].item() != token_count:
|
||||
raise ValueError(f"{name} does not match token count")
|
||||
|
||||
|
||||
def _split_indices(cu_seqlens):
|
||||
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||
|
||||
|
||||
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
|
||||
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
|
||||
if cu_seqlens_k[-1].item() != v.shape[0]:
|
||||
q_split_indices = cu_seqlens_q[1:-1]
|
||||
k_split_indices = cu_seqlens_k[1:-1]
|
||||
if k.shape[0] != v.shape[0]:
|
||||
raise ValueError("cu_seqlens_k does not match v token count")
|
||||
|
||||
q_split_indices = _split_indices(cu_seqlens_q)
|
||||
k_split_indices = _split_indices(cu_seqlens_k)
|
||||
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
|
||||
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
|
||||
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
|
||||
|
||||
@ -45,7 +45,6 @@ def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
# Resize style to match content spatial dimensions
|
||||
if len(content_feat.shape) >= 3:
|
||||
style_feat = F.interpolate(
|
||||
style_feat,
|
||||
@ -54,12 +53,11 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Decompose both features into frequency components
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq # Free memory immediately
|
||||
del content_low_freq
|
||||
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq # Free memory immediately
|
||||
del style_high_freq
|
||||
|
||||
if content_high_freq.shape != style_low_freq.shape:
|
||||
style_low_freq = F.interpolate(
|
||||
@ -73,27 +71,23 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
return content_high_freq.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=source.device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
@ -101,7 +95,6 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
@ -109,17 +102,14 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
def _lab_to_rgb_batch(lab: Tensor, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
torch.pow(fx, 3.0),
|
||||
@ -137,20 +127,16 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(D65_WHITE_X)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(D65_WHITE_Z)
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
B, _, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
@ -158,7 +144,6 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
mask,
|
||||
@ -169,9 +154,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||
# Apply sRGB gamma correction (linearize)
|
||||
def _rgb_to_lab_batch(rgb: Tensor, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
mask = rgb > 0.04045
|
||||
rgb_linear = torch.where(
|
||||
mask,
|
||||
@ -180,12 +163,10 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
||||
)
|
||||
del mask
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
B, _, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
@ -193,12 +174,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(D65_WHITE_X) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(D65_WHITE_Z) # Z
|
||||
xyz[:, 0].div_(D65_WHITE_X)
|
||||
xyz[:, 2].div_(D65_WHITE_Z)
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
f_xyz = torch.where(
|
||||
@ -208,10 +186,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0)
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0)
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0)
|
||||
del f_xyz
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
@ -232,13 +209,9 @@ def lab_color_transfer(
|
||||
)
|
||||
|
||||
device = content_feat.device
|
||||
|
||||
def ensure_float32_precision(c):
|
||||
orig_dtype = c.dtype
|
||||
c = c.float()
|
||||
return c, orig_dtype
|
||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||
style_feat, _ = ensure_float32_precision(style_feat)
|
||||
original_dtype = content_feat.dtype
|
||||
content_feat = content_feat.float()
|
||||
style_feat = style_feat.float()
|
||||
|
||||
rgb_to_xyz_matrix = torch.tensor([
|
||||
[0.4124564, 0.3575761, 0.1804375],
|
||||
@ -258,39 +231,30 @@ def lab_color_transfer(
|
||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
|
||||
# Convert to LAB color space
|
||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
content_lab = _rgb_to_lab_batch(content_feat, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del content_feat
|
||||
|
||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
style_lab = _rgb_to_lab_batch(style_feat, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del style_feat, rgb_to_xyz_matrix
|
||||
|
||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1])
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2])
|
||||
|
||||
# Handle luminance with weighted blending
|
||||
if luminance_weight < 1.0:
|
||||
# Partially match luminance for better overall color accuracy
|
||||
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0])
|
||||
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||
del matched_L
|
||||
else:
|
||||
# Fully preserve content luminance
|
||||
result_L = content_lab[:, 0]
|
||||
|
||||
del content_lab, style_lab
|
||||
|
||||
# Reconstruct LAB with corrected channels
|
||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||
del result_L, matched_a, matched_b
|
||||
|
||||
# Convert back to RGB
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
del result_lab, xyz_to_rgb_matrix
|
||||
|
||||
# Convert back to [-1, 1] range (in-place)
|
||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||
del result_rgb
|
||||
|
||||
|
||||
@ -1,34 +1,21 @@
|
||||
"""Named constants for the SeedVR2 integration, grouped by provenance.
|
||||
"""SeedVR2 constants."""
|
||||
|
||||
Provenance prefixes:
|
||||
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
|
||||
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
|
||||
the upstream config/source path it was lifted from.
|
||||
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
||||
ISO / CIE values; cite the standard.
|
||||
"""
|
||||
|
||||
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
|
||||
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry.
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
|
||||
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
|
||||
SEEDVR2_7B_VID_DIM = 3072
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR = 2
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR = 4
|
||||
SEEDVR2_7B_MLP_CHUNK = 8192
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
|
||||
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
|
||||
SEEDVR2_LATENT_CHANNELS = 16
|
||||
|
||||
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
|
||||
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
|
||||
SEEDVR2_COLOR_MEM_HEADROOM = 0.75
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER = 13
|
||||
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
|
||||
# --------------------------------------------------------------------------------------
|
||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
|
||||
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
|
||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57.
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0
|
||||
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5
|
||||
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5
|
||||
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
|
||||
@ -42,18 +29,10 @@ BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal wind
|
||||
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
|
||||
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Published standards (cite the literature)
|
||||
# --------------------------------------------------------------------------------------
|
||||
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
|
||||
|
||||
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
|
||||
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
|
||||
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
|
||||
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
|
||||
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
|
||||
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
|
||||
|
||||
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
|
||||
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
|
||||
# exact existing coefficients move verbatim rather than being retyped here.
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, Callable
|
||||
import torch.nn.functional as F
|
||||
from math import ceil, pi
|
||||
import torch
|
||||
from itertools import chain
|
||||
from itertools import accumulate, chain
|
||||
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
|
||||
from comfy.ldm.seedvr.attention import optimized_var_attention
|
||||
from torch.nn.modules.utils import _triple
|
||||
@ -18,6 +18,7 @@ from comfy.ldm.seedvr.constants import (
|
||||
ROPE_THETA,
|
||||
SEEDVR2_7B_MLP_CHUNK,
|
||||
SEEDVR2_7B_VID_DIM,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
|
||||
)
|
||||
import comfy.model_management
|
||||
@ -70,7 +71,7 @@ def repeat_concat_idx(
|
||||
vid_idx = torch.arange(vid_len.sum(), device=device)
|
||||
txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
|
||||
txt_repeat_list = txt_repeat.tolist()
|
||||
tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
|
||||
tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list)
|
||||
src_idx = torch.argsort(tgt_idx)
|
||||
txt_idx_len = len(tgt_idx) - len(vid_idx)
|
||||
repeat_txt_len = (txt_len * txt_repeat).tolist()
|
||||
@ -88,6 +89,9 @@ def repeat_concat_idx(
|
||||
lambda all: unconcat_coalesce(all),
|
||||
)
|
||||
|
||||
def cumulative_lengths(lengths):
|
||||
return [0, *accumulate(lengths)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MMArg:
|
||||
@ -110,16 +114,14 @@ def get_window_op(name: str):
|
||||
raise ValueError(f"Unknown windowing method: {name}")
|
||||
|
||||
|
||||
# -------------------------------- Windowing -------------------------------- #
|
||||
def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
|
||||
t, h, w = size
|
||||
resized_nt, resized_nh, resized_nw = num_windows
|
||||
#cal windows under 720p
|
||||
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
|
||||
resized_h, resized_w = round(h * scale), round(w * scale)
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
|
||||
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
|
||||
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size.
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
|
||||
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
|
||||
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)
|
||||
return [
|
||||
(
|
||||
slice(it * wt, min((it + 1) * wt, t)),
|
||||
@ -137,19 +139,18 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int,
|
||||
def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
|
||||
t, h, w = size
|
||||
resized_nt, resized_nh, resized_nw = num_windows
|
||||
#cal windows under 720p
|
||||
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
|
||||
resized_h, resized_w = round(h * scale), round(w * scale)
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
|
||||
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
|
||||
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
|
||||
|
||||
st, sh, sw = ( # shift size.
|
||||
st, sh, sw = (
|
||||
0.5 if wt < t else 0,
|
||||
0.5 if wh < h else 0,
|
||||
0.5 if ww < w else 0,
|
||||
)
|
||||
nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size.
|
||||
nt, nh, nw = ( # number of window.
|
||||
nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
|
||||
nt, nh, nw = (
|
||||
nt + 1 if st > 0 else 1,
|
||||
nh + 1 if sh > 0 else 1,
|
||||
nw + 1 if sw > 0 else 1,
|
||||
@ -175,7 +176,6 @@ class RotaryEmbedding(nn.Module):
|
||||
freqs_for = 'lang',
|
||||
theta = 10000,
|
||||
max_freq = 10,
|
||||
learned_freq = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -185,18 +185,14 @@ class RotaryEmbedding(nn.Module):
|
||||
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
||||
elif freqs_for == 'pixel':
|
||||
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
||||
else:
|
||||
raise ValueError(f"Unknown rotary frequency type: {freqs_for}")
|
||||
|
||||
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
|
||||
|
||||
self.learned_freq = learned_freq
|
||||
|
||||
# dummy for device
|
||||
|
||||
self.register_buffer('dummy', torch.tensor(0), persistent = False)
|
||||
self.register_buffer("freqs", freqs)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.dummy.device
|
||||
return self.freqs.device
|
||||
|
||||
def get_axial_freqs(
|
||||
self,
|
||||
@ -206,10 +202,9 @@ class RotaryEmbedding(nn.Module):
|
||||
Colon = slice(None)
|
||||
all_freqs = []
|
||||
|
||||
# handle offset
|
||||
|
||||
if exists(offsets):
|
||||
assert len(offsets) == len(dims)
|
||||
if len(offsets) != len(dims):
|
||||
raise ValueError(f"SeedVR2 rotary offsets length must match dims length, got {len(offsets)} and {len(dims)}.")
|
||||
|
||||
for ind, dim in enumerate(dims):
|
||||
|
||||
@ -224,7 +219,7 @@ class RotaryEmbedding(nn.Module):
|
||||
|
||||
pos = pos + offset
|
||||
|
||||
freqs = self.forward(pos, seq_len = dim)
|
||||
freqs = self.forward(pos)
|
||||
|
||||
all_axis = [None] * len(dims)
|
||||
all_axis[ind] = Colon
|
||||
@ -232,16 +227,12 @@ class RotaryEmbedding(nn.Module):
|
||||
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
||||
all_freqs.append(freqs[new_axis_slice])
|
||||
|
||||
# concat all freqs
|
||||
|
||||
all_freqs = torch.broadcast_tensors(*all_freqs)
|
||||
return torch.cat(all_freqs, dim = -1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
t,
|
||||
seq_len: int | None = None,
|
||||
offset = 0
|
||||
):
|
||||
freqs = self.freqs
|
||||
|
||||
@ -258,9 +249,6 @@ class RotaryEmbeddingBase(nn.Module):
|
||||
freqs_for="pixel",
|
||||
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
|
||||
)
|
||||
freqs = self.rope.freqs
|
||||
del self.rope.freqs
|
||||
self.rope.register_buffer("freqs", freqs.detach())
|
||||
|
||||
def get_axial_freqs(self, *dims):
|
||||
return self.rope.get_axial_freqs(*dims)
|
||||
@ -306,7 +294,7 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
|
||||
freqs_for="pixel",
|
||||
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
|
||||
)
|
||||
plain_rope = plain_rope.to(self.rope.dummy.device)
|
||||
plain_rope = plain_rope.to(self.rope.device)
|
||||
freq_list = []
|
||||
for f, h, w in shape.tolist():
|
||||
freqs = plain_rope.get_axial_freqs(f, h, w)
|
||||
@ -322,9 +310,6 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
|
||||
freqs_for="lang",
|
||||
theta=ROPE_THETA,
|
||||
)
|
||||
freqs = self.rope.freqs
|
||||
del self.rope.freqs
|
||||
self.rope.register_buffer("freqs", freqs.detach())
|
||||
self.mm = True
|
||||
|
||||
def slice_at_dim(t, dim_slice: slice, *, dim):
|
||||
@ -333,8 +318,6 @@ def slice_at_dim(t, dim_slice: slice, *, dim):
|
||||
colons[dim] = dim_slice
|
||||
return t[tuple(colons)]
|
||||
|
||||
# rotary embedding helper functions
|
||||
|
||||
def rotate_half(x):
|
||||
x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
|
||||
x1, x2 = x.unbind(dim = -1)
|
||||
@ -373,7 +356,6 @@ def _apply_seedvr2_rotary_emb(
|
||||
return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype)
|
||||
|
||||
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
|
||||
angles = freqs_interleaved[..., ::2].float()
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
@ -382,12 +364,6 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
"""Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest
|
||||
through; in-place for inference, cloned for training (autograd). Mirrors the legacy
|
||||
``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives
|
||||
``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when
|
||||
``rot_d == t.shape[-1]``.
|
||||
"""
|
||||
out = t.clone() if t.requires_grad or comfy.model_management.in_training else t
|
||||
rot_d = 2 * freqs_cis.shape[-3]
|
||||
seq_len = out.shape[-2]
|
||||
@ -454,14 +430,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
||||
torch.Tensor,
|
||||
]:
|
||||
|
||||
# Calculate actual max dimensions needed for this batch
|
||||
max_temporal = 0
|
||||
max_height = 0
|
||||
max_width = 0
|
||||
max_txt_len = 0
|
||||
|
||||
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||
max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal
|
||||
max_temporal = max(max_temporal, l + f)
|
||||
max_height = max(max_height, h)
|
||||
max_width = max(max_width, w)
|
||||
max_txt_len = max(max_txt_len, l)
|
||||
@ -475,7 +450,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
||||
).float()
|
||||
txt_freqs = self.get_axial_freqs(max_txt_len + 16)
|
||||
|
||||
# Now slice as before
|
||||
vid_freq_list, txt_freq_list = [], []
|
||||
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
|
||||
@ -485,13 +459,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
||||
vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0)
|
||||
txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0)
|
||||
|
||||
# Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]`
|
||||
# (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the
|
||||
# upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis`
|
||||
# in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in.
|
||||
# Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing
|
||||
# 2x2 is the per-frequency rotation matrix that
|
||||
# `comfy.ldm.flux.math.apply_rope1` expects.
|
||||
return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved)
|
||||
|
||||
class MMModule(nn.Module):
|
||||
@ -507,8 +474,10 @@ class MMModule(nn.Module):
|
||||
self.shared_weights = shared_weights
|
||||
self.vid_only = vid_only
|
||||
if self.shared_weights:
|
||||
assert get_args("vid", args) == get_args("txt", args)
|
||||
assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
|
||||
if get_args("vid", args) != get_args("txt", args):
|
||||
raise ValueError("SeedVR2 shared MMModule requires matching vid/txt args.")
|
||||
if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs):
|
||||
raise ValueError("SeedVR2 shared MMModule requires matching vid/txt kwargs.")
|
||||
self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
|
||||
else:
|
||||
self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
|
||||
@ -543,6 +512,7 @@ def get_na_rope(rope_type: Optional[str], dim: int):
|
||||
return NaRotaryEmbedding3d(dim=dim)
|
||||
if rope_type == "mmrope3d":
|
||||
return NaMMRotaryEmbedding3d(dim=dim)
|
||||
raise ValueError(f"Unknown SeedVR2 rope type: {rope_type}")
|
||||
|
||||
class NaMMAttention(nn.Module):
|
||||
def __init__(
|
||||
@ -558,7 +528,6 @@ class NaMMAttention(nn.Module):
|
||||
rope_dim: int,
|
||||
shared_weights: bool,
|
||||
device, dtype, operations,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
dim = MMArg(vid_dim, txt_dim)
|
||||
@ -597,16 +566,19 @@ def window(
|
||||
):
|
||||
hid = unflatten(hid, hid_shape)
|
||||
hid = list(map(window_fn, hid))
|
||||
hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device)
|
||||
hid, hid_shape = flatten(list(chain(*hid)))
|
||||
return hid, hid_shape, hid_windows
|
||||
hid_windows_list = [len(x) for x in hid]
|
||||
hid_windows = torch.as_tensor(hid_windows_list, device=hid_shape.device)
|
||||
hid = list(chain(*hid))
|
||||
hid_len_list = [math.prod(x.shape[:-1]) for x in hid]
|
||||
hid, hid_shape = flatten(hid)
|
||||
return hid, hid_shape, hid_windows, hid_len_list, hid_windows_list
|
||||
|
||||
def window_idx(
|
||||
hid_shape: torch.LongTensor, # (b n)
|
||||
window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
|
||||
):
|
||||
hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
|
||||
tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
|
||||
tgt_idx, tgt_shape, tgt_windows, tgt_len_list, tgt_windows_list = window(hid_idx, hid_shape, window_fn)
|
||||
tgt_idx = tgt_idx.squeeze(-1)
|
||||
src_idx = torch.argsort(tgt_idx)
|
||||
return (
|
||||
@ -614,6 +586,8 @@ def window_idx(
|
||||
lambda hid: torch.index_select(hid, 0, src_idx),
|
||||
tgt_shape,
|
||||
tgt_windows,
|
||||
tgt_len_list,
|
||||
tgt_windows_list,
|
||||
)
|
||||
|
||||
class NaSwinAttention(NaMMAttention):
|
||||
@ -622,13 +596,15 @@ class NaSwinAttention(NaMMAttention):
|
||||
*args,
|
||||
window: Union[int, Tuple[int, int, int]],
|
||||
window_method: str,
|
||||
version: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.version_7b = kwargs.get("version", False)
|
||||
self.version_7b = version
|
||||
self.window = _triple(window)
|
||||
self.window_method = window_method
|
||||
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
|
||||
if not all(isinstance(v, int) and v >= 0 for v in self.window):
|
||||
raise ValueError(f"SeedVR2 window must contain non-negative integers, got {self.window}.")
|
||||
|
||||
self.window_op = get_window_op(window_method)
|
||||
|
||||
@ -646,7 +622,6 @@ class NaSwinAttention(NaMMAttention):
|
||||
|
||||
vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
|
||||
|
||||
# re-org the input seq for window attn
|
||||
cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
|
||||
|
||||
def make_window(x: torch.Tensor):
|
||||
@ -654,7 +629,7 @@ class NaSwinAttention(NaMMAttention):
|
||||
window_slices = self.window_op((t, h, w), self.window)
|
||||
return [x[st, sh, sw] for (st, sh, sw) in window_slices]
|
||||
|
||||
window_partition, window_reverse, window_shape, window_count = cache_win(
|
||||
window_partition, window_reverse, window_shape, window_count, vid_len_win_list, window_count_list = cache_win(
|
||||
"win_transform",
|
||||
lambda: window_idx(vid_shape, make_window),
|
||||
)
|
||||
@ -674,23 +649,21 @@ class NaSwinAttention(NaMMAttention):
|
||||
vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
|
||||
txt_len = txt_len.to(window_count.device)
|
||||
|
||||
# window rope
|
||||
if self.rope:
|
||||
if self.version_7b:
|
||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||
elif self.rope.mm:
|
||||
# repeat text q and k for window mmrope
|
||||
_, num_h, _ = txt_q.shape
|
||||
txt_q_repeat = txt_q.flatten(1, 2)
|
||||
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
|
||||
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
|
||||
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count_list)]
|
||||
txt_q_repeat = list(chain(*txt_q_repeat))
|
||||
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
|
||||
txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim)
|
||||
|
||||
txt_k_repeat = txt_k.flatten(1, 2)
|
||||
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
|
||||
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
|
||||
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count_list)]
|
||||
txt_k_repeat = list(chain(*txt_k_repeat))
|
||||
txt_k_repeat, _ = flatten(txt_k_repeat)
|
||||
txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim)
|
||||
@ -702,7 +675,11 @@ class NaSwinAttention(NaMMAttention):
|
||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||
|
||||
txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
|
||||
all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
|
||||
txt_len_win_list = cache_win(
|
||||
"txt_len_list",
|
||||
lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)],
|
||||
)
|
||||
all_len_win = cache_win("all_len", lambda: [vid_len + txt_len for vid_len, txt_len in zip(vid_len_win_list, txt_len_win_list)])
|
||||
concat_win, unconcat_win = cache_win(
|
||||
"mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
|
||||
)
|
||||
@ -711,12 +688,8 @@ class NaSwinAttention(NaMMAttention):
|
||||
k=concat_win(vid_k, txt_k),
|
||||
v=concat_win(vid_v, txt_v),
|
||||
heads=self.heads, skip_reshape=True, skip_output_reshape=True,
|
||||
cu_seqlens_q=cache_win(
|
||||
"vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
cu_seqlens_k=cache_win(
|
||||
"vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
cu_seqlens_q=cache_win("vid_seqlens_q", lambda: cumulative_lengths(all_len_win)),
|
||||
cu_seqlens_k=cache_win("vid_seqlens_k", lambda: cumulative_lengths(all_len_win)),
|
||||
)
|
||||
vid_out, txt_out = unconcat_win(out)
|
||||
|
||||
@ -766,11 +739,11 @@ class SwiGLUMLP(nn.Module):
|
||||
return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
|
||||
|
||||
def get_mlp(mlp_type: Optional[str] = "normal"):
|
||||
# 3b and 7b uses different mlp types
|
||||
if mlp_type == "normal":
|
||||
return MLP
|
||||
elif mlp_type == "swiglu":
|
||||
if mlp_type == "swiglu":
|
||||
return SwiGLUMLP
|
||||
raise ValueError(f"Unknown SeedVR2 MLP type: {mlp_type}")
|
||||
|
||||
class NaMMSRTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
@ -792,11 +765,12 @@ class NaMMSRTransformerBlock(nn.Module):
|
||||
rope_type: str,
|
||||
rope_dim: int,
|
||||
is_last_layer: bool,
|
||||
window: Union[int, Tuple[int, int, int]],
|
||||
window_method: str,
|
||||
version: bool,
|
||||
device, dtype, operations,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
version = kwargs.get("version", False)
|
||||
dim = MMArg(vid_dim, txt_dim)
|
||||
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
|
||||
|
||||
@ -811,8 +785,8 @@ class NaMMSRTransformerBlock(nn.Module):
|
||||
rope_type=rope_type,
|
||||
rope_dim=rope_dim,
|
||||
shared_weights=shared_weights,
|
||||
window=kwargs.pop("window", None),
|
||||
window_method=kwargs.pop("window_method", None),
|
||||
window=window,
|
||||
window_method=window_method,
|
||||
version=version,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
@ -930,12 +904,14 @@ class NaPatchOut(PatchOut):
|
||||
self,
|
||||
vid: torch.FloatTensor, # l c
|
||||
vid_shape: torch.LongTensor,
|
||||
cache: Cache = Cache(disable=True),
|
||||
cache: Optional[Cache] = None,
|
||||
vid_shape_before_patchify = None
|
||||
) -> Tuple[
|
||||
torch.FloatTensor,
|
||||
torch.LongTensor,
|
||||
]:
|
||||
if cache is None:
|
||||
cache = Cache(disable=True)
|
||||
|
||||
t, h, w = self.patch_size
|
||||
vid = self.proj(vid)
|
||||
@ -971,7 +947,10 @@ class PatchIn(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
t, h, w = self.patch_size
|
||||
if t > 1:
|
||||
assert vid.size(2) % t == 1
|
||||
if vid.size(2) % t != 1:
|
||||
raise ValueError(
|
||||
f"SeedVR2 patch input temporal size must satisfy T % {t} == 1, got {vid.size(2)}."
|
||||
)
|
||||
vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
|
||||
b, c, Tt, Hh, Ww = vid.shape
|
||||
vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c)
|
||||
@ -983,8 +962,10 @@ class NaPatchIn(PatchIn):
|
||||
self,
|
||||
vid: torch.Tensor, # l c
|
||||
vid_shape: torch.LongTensor,
|
||||
cache: Cache = Cache(disable=True),
|
||||
cache: Optional[Cache] = None,
|
||||
) -> torch.Tensor:
|
||||
if cache is None:
|
||||
cache = Cache(disable=True)
|
||||
cache = cache.namespace("patch")
|
||||
vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
|
||||
t, h, w = self.patch_size
|
||||
@ -1012,10 +993,11 @@ class AdaSingle(nn.Module):
|
||||
dim: int,
|
||||
emb_dim: int,
|
||||
layers: List[str],
|
||||
modes: List[str] = ["in", "out"],
|
||||
modes: Tuple[str, ...] = ("in", "out"),
|
||||
device = None, dtype = None,
|
||||
):
|
||||
assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
|
||||
if emb_dim != 6 * dim:
|
||||
raise ValueError(f"SeedVR2 AdaSingle requires emb_dim == 6 * dim, got emb_dim={emb_dim}, dim={dim}.")
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.emb_dim = emb_dim
|
||||
@ -1036,22 +1018,20 @@ class AdaSingle(nn.Module):
|
||||
emb: torch.FloatTensor, # b d
|
||||
layer: str,
|
||||
mode: str,
|
||||
cache: Cache = Cache(disable=True),
|
||||
cache: Optional[Cache] = None,
|
||||
branch_tag: str = "",
|
||||
hid_len: Optional[torch.LongTensor] = None, # b
|
||||
) -> torch.FloatTensor:
|
||||
if cache is None:
|
||||
cache = Cache(disable=True)
|
||||
idx = self.layers.index(layer)
|
||||
emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :]
|
||||
emb = expand_dims(emb, 1, hid.ndim + 1)
|
||||
|
||||
if hid_len is not None:
|
||||
slice_inputs = lambda x, dim: x
|
||||
emb = cache(
|
||||
f"emb_repeat_{idx}_{branch_tag}",
|
||||
lambda: slice_inputs(
|
||||
torch.repeat_interleave(emb, hid_len, dim=0),
|
||||
dim=0,
|
||||
),
|
||||
lambda: torch.repeat_interleave(emb, hid_len, dim=0),
|
||||
)
|
||||
|
||||
shiftA, scaleA, gateA = emb.unbind(-1)
|
||||
@ -1069,7 +1049,7 @@ class AdaSingle(nn.Module):
|
||||
else:
|
||||
return hid.mul_(gateA)
|
||||
|
||||
raise NotImplementedError
|
||||
raise ValueError(f"Unknown AdaSingle mode: {mode}")
|
||||
|
||||
|
||||
class TimeEmbedding(nn.Module):
|
||||
@ -1117,7 +1097,8 @@ def flatten(
|
||||
torch.FloatTensor, # (L c)
|
||||
torch.LongTensor, # (b n)
|
||||
]:
|
||||
assert len(hid) > 0
|
||||
if len(hid) == 0:
|
||||
raise ValueError("SeedVR2 flatten requires at least one tensor.")
|
||||
shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device)
|
||||
hid = torch.cat([x.flatten(0, -2) for x in hid])
|
||||
return hid, shape
|
||||
@ -1140,7 +1121,7 @@ class NaDiT(nn.Module):
|
||||
num_layers,
|
||||
mlp_type,
|
||||
vid_in_channels = 33,
|
||||
vid_out_channels = 16,
|
||||
vid_out_channels = SEEDVR2_LATENT_CHANNELS,
|
||||
vid_dim = 2560,
|
||||
txt_in_dim = 5120,
|
||||
heads = 20,
|
||||
@ -1148,15 +1129,17 @@ class NaDiT(nn.Module):
|
||||
mm_layers = 10,
|
||||
expand_ratio = 4,
|
||||
qk_bias = False,
|
||||
patch_size = [ 1,2,2 ],
|
||||
patch_size = (1, 2, 2),
|
||||
rope_dim = 128,
|
||||
rope_type = "mmrope3d",
|
||||
vid_out_norm: Optional[str] = None,
|
||||
image_model = None,
|
||||
device = None,
|
||||
dtype = None,
|
||||
operations = None,
|
||||
**kwargs,
|
||||
):
|
||||
if image_model not in (None, "seedvr2"):
|
||||
raise ValueError(f"SeedVR2 NaDiT expected image_model='seedvr2', got {image_model!r}.")
|
||||
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
|
||||
if self._7b_version:
|
||||
rope_type = "rope3d"
|
||||
@ -1212,14 +1195,13 @@ class NaDiT(nn.Module):
|
||||
rope_dim = rope_dim,
|
||||
window=window[i],
|
||||
window_method=window_method[i],
|
||||
version = self._7b_version,
|
||||
is_last_layer=(i == num_layers - 1) and not self._7b_version,
|
||||
rope_type = rope_type,
|
||||
shared_weights=not (
|
||||
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
|
||||
),
|
||||
version = self._7b_version,
|
||||
operations = operations,
|
||||
**kwargs,
|
||||
**factory_kwargs
|
||||
)
|
||||
for i in range(num_layers)
|
||||
@ -1272,13 +1254,17 @@ class NaDiT(nn.Module):
|
||||
first = cond_or_uncond[0]
|
||||
return all(entry == first for entry in cond_or_uncond)
|
||||
|
||||
@staticmethod
|
||||
def _check_seedvr2_video_latent(x, channels, name):
|
||||
if x.ndim != 5:
|
||||
raise ValueError(f"SeedVR2 expected {name} to be 5-D native latent, got shape {tuple(x.shape)}.")
|
||||
if x.shape[1] != channels:
|
||||
raise ValueError(f"SeedVR2 expected {name} channels to be {channels}, got shape {tuple(x.shape)}.")
|
||||
return x
|
||||
|
||||
def _swap_pos_neg_halves(self, out, cond_or_uncond=None):
|
||||
if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
|
||||
return out
|
||||
# ``dim=0`` is explicit on both calls. The contract is "split
|
||||
# the batch axis into two halves and swap them"; making the
|
||||
# axis load-bearing in source guards against silent drift if a
|
||||
# future refactor reorders tensor axes.
|
||||
pos, neg = out.chunk(2, dim=0)
|
||||
return torch.cat([neg, pos], dim=0)
|
||||
|
||||
@ -1294,9 +1280,15 @@ class NaDiT(nn.Module):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
conditions = kwargs.get("condition")
|
||||
b, tc, h, w = x.shape
|
||||
x = x.view(b, 16, -1, h, w)
|
||||
conditions = conditions.view(b, 17, -1, h, w)
|
||||
if conditions is None:
|
||||
raise ValueError("SeedVR2 requires conditioning latents from the SeedVR2Conditioning node.")
|
||||
x = self._check_seedvr2_video_latent(x, SEEDVR2_LATENT_CHANNELS, "latent")
|
||||
conditions = self._check_seedvr2_video_latent(conditions, SEEDVR2_LATENT_CHANNELS + 1, "conditioning")
|
||||
b, _, t, h, w = x.shape
|
||||
if conditions.shape[0] != b or conditions.shape[2:] != (t, h, w):
|
||||
raise ValueError(
|
||||
f"SeedVR2 conditioning shape must match latent batch/temporal/spatial dimensions; got latent {tuple(x.shape)} and conditioning {tuple(conditions.shape)}."
|
||||
)
|
||||
x = x.movedim(1, -1)
|
||||
conditions = conditions.movedim(1, -1)
|
||||
cache = Cache(disable=disable_cache)
|
||||
@ -1361,7 +1353,6 @@ class NaDiT(nn.Module):
|
||||
|
||||
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
|
||||
vid = unflatten(vid, vid_shape)
|
||||
out = torch.stack(vid)
|
||||
out = torch.stack(vid)
|
||||
out = out.movedim(-1, 1)
|
||||
out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4])
|
||||
return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond"))
|
||||
|
||||
@ -62,7 +62,6 @@ def tiled_vae(
|
||||
temporal_size=16,
|
||||
temporal_overlap=0,
|
||||
encode=True,
|
||||
**kwargs,
|
||||
):
|
||||
if x.ndim != 5:
|
||||
x = x.unsqueeze(2)
|
||||
@ -166,8 +165,8 @@ def tiled_vae(
|
||||
|
||||
if single_spatial_tile:
|
||||
result = tile_out[:, :, :target_d, :target_h, :target_w]
|
||||
if result.device != x.device:
|
||||
result = result.to(x.device).to(x.dtype)
|
||||
if result.device != x.device or result.dtype != x.dtype:
|
||||
result = result.to(device=x.device, dtype=x.dtype)
|
||||
if x.shape[2] == 1 and sf_t == 1:
|
||||
result = result.squeeze(2)
|
||||
bar.update(1)
|
||||
@ -221,8 +220,8 @@ def tiled_vae(
|
||||
|
||||
result.div_(count.clamp(min=1e-6))
|
||||
|
||||
if result.device != x.device:
|
||||
result = result.to(x.device).to(x.dtype)
|
||||
if result.device != x.device or result.dtype != x.dtype:
|
||||
result = result.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
if x.shape[2] == 1 and sf_t == 1:
|
||||
result = result.squeeze(2)
|
||||
@ -256,15 +255,18 @@ class MemoryState(Enum):
|
||||
UNSET = 3
|
||||
|
||||
def get_cache_size(conv_module, input_len, pad_len, dim=0):
|
||||
dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
|
||||
output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1
|
||||
dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
|
||||
output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
|
||||
remain_len = (
|
||||
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size)
|
||||
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
|
||||
)
|
||||
overlap_len = dilated_kernerl_size - conv_module.stride[dim]
|
||||
cache_len = overlap_len + remain_len # >= 0
|
||||
overlap_len = dilated_kernel_size - conv_module.stride[dim]
|
||||
cache_len = overlap_len + remain_len
|
||||
|
||||
assert output_len > 0
|
||||
if output_len <= 0:
|
||||
raise ValueError(
|
||||
f"SeedVR2 VAE cache input is too short for convolution: input_len={input_len}, pad_len={pad_len}."
|
||||
)
|
||||
return cache_len
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
@ -294,52 +296,27 @@ class SpatialNorm(nn.Module):
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
# partial implementation of diffusers's Attention for comfyui
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
spatial_norm_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
out_dim: int = None,
|
||||
pre_only=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.inner_dim = dim_head * heads
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.pre_only = pre_only
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.only_cross_attention = only_cross_attention
|
||||
self.out_dim = query_dim
|
||||
self.heads = heads
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
@ -351,37 +328,19 @@ class Attention(nn.Module):
|
||||
else:
|
||||
self.spatial_norm = None
|
||||
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
|
||||
self.norm_cross = None
|
||||
self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
else:
|
||||
self.to_out = None
|
||||
self.to_k = ops.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = ops.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Identity())
|
||||
|
||||
self.optimized_vae_attention = vae_attention()
|
||||
|
||||
def __call__(
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
residual = hidden_states
|
||||
@ -394,20 +353,14 @@ class Attention(nn.Module):
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
if self.group_norm is not None:
|
||||
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // self.heads
|
||||
@ -417,25 +370,18 @@ class Attention(nn.Module):
|
||||
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if self.norm_q is not None:
|
||||
query = self.norm_q(query)
|
||||
if self.norm_k is not None:
|
||||
key = self.norm_k(key)
|
||||
|
||||
if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1:
|
||||
if input_ndim == 4 and self.heads == 1:
|
||||
query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
|
||||
key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
|
||||
value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
|
||||
hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3)
|
||||
else:
|
||||
hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True)
|
||||
hidden_states = optimized_attention(query, key, value, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
@ -471,7 +417,10 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
memory_occupy = x.numel() * x.element_size() / 1024**3
|
||||
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
|
||||
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
|
||||
assert norm_layer.num_groups % num_chunks == 0
|
||||
if norm_layer.num_groups % num_chunks != 0:
|
||||
raise ValueError(
|
||||
f"SeedVR2 VAE GroupNorm groups must divide chunks: groups={norm_layer.num_groups}, chunks={num_chunks}."
|
||||
)
|
||||
num_groups_per_chunk = norm_layer.num_groups // num_chunks
|
||||
|
||||
x = list(x.chunk(num_chunks, dim=1))
|
||||
@ -485,14 +434,15 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
x = norm_layer(x)
|
||||
x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2)
|
||||
return x.to(input_dtype)
|
||||
raise NotImplementedError
|
||||
raise TypeError(f"SeedVR2 VAE unsupported norm layer type: {type(norm_layer).__name__}")
|
||||
|
||||
_receptive_field_t = Literal["half", "full"]
|
||||
|
||||
def extend_head(tensor, times: int = 2, memory = None):
|
||||
if memory is not None:
|
||||
return torch.cat((memory.to(tensor), tensor), dim=2)
|
||||
assert times >= 0, "Invalid input for function 'extend_head'!"
|
||||
if times < 0:
|
||||
raise ValueError(f"SeedVR2 VAE extend_head expected times >= 0, got {times}.")
|
||||
if times == 0:
|
||||
return tensor
|
||||
else:
|
||||
@ -547,13 +497,11 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
padding=(0, 0, 0, 0, 0, 0),
|
||||
prev_cache=None,
|
||||
):
|
||||
# Compatible with no limit.
|
||||
if math.isinf(self.memory_limit):
|
||||
if prev_cache is not None:
|
||||
x = torch.cat([prev_cache, x], dim=split_dim - 1)
|
||||
return super().forward(x)
|
||||
|
||||
# Compute tensor shape after concat & padding.
|
||||
shape = list(x.size())
|
||||
if prev_cache is not None:
|
||||
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
|
||||
@ -597,16 +545,19 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
|
||||
next_cache = None
|
||||
cache_len = cache.size(split_dim) if cache is not None else 0
|
||||
next_catch_size = get_cache_size(
|
||||
next_cache_size = get_cache_size(
|
||||
conv_module=self,
|
||||
input_len=x[idx].size(split_dim) + cache_len,
|
||||
pad_len=pad_len,
|
||||
dim=split_dim - 2,
|
||||
)
|
||||
if next_catch_size != 0:
|
||||
assert next_catch_size <= x[idx].size(split_dim)
|
||||
if next_cache_size != 0:
|
||||
if next_cache_size > x[idx].size(split_dim):
|
||||
raise ValueError(
|
||||
f"SeedVR2 VAE cache size {next_cache_size} exceeds split size {x[idx].size(split_dim)}."
|
||||
)
|
||||
next_cache = (
|
||||
x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim)
|
||||
x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim)
|
||||
)
|
||||
|
||||
x[idx] = self.memory_limit_conv(
|
||||
@ -627,7 +578,8 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
memory_state: MemoryState = MemoryState.UNSET,
|
||||
memory_cache = None,
|
||||
) -> Tensor:
|
||||
assert memory_state != MemoryState.UNSET
|
||||
if memory_state == MemoryState.UNSET:
|
||||
raise ValueError("SeedVR2 VAE convolution requires an explicit MemoryState.")
|
||||
if memory_cache is None:
|
||||
memory_cache = {}
|
||||
if memory_state != MemoryState.ACTIVE:
|
||||
@ -677,9 +629,8 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2
|
||||
)
|
||||
|
||||
# Single GPU inference - simplified memory management
|
||||
if (
|
||||
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing
|
||||
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE]
|
||||
and cache_size != 0
|
||||
):
|
||||
if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
|
||||
@ -690,7 +641,6 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
|
||||
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
|
||||
for i in range(len(input)):
|
||||
# Prepare cache for next input slice.
|
||||
next_cache = None
|
||||
cache_size = 0
|
||||
if i < len(input) - 1:
|
||||
@ -700,17 +650,16 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
if cache_size > input[i].size(2) and cache is not None:
|
||||
input[i] = torch.cat([cache, input[i]], dim=2)
|
||||
cache = None
|
||||
assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
|
||||
if cache_size > input[i].size(2):
|
||||
raise ValueError(f"SeedVR2 VAE cache size {cache_size} exceeds input length {input[i].size(2)}.")
|
||||
next_cache = input[i][:, :, -cache_size:]
|
||||
|
||||
# Conv forward for this input slice.
|
||||
input[i] = self.memory_limit_conv(
|
||||
input[i],
|
||||
padding=padding,
|
||||
prev_cache=cache
|
||||
)
|
||||
|
||||
# Update cache.
|
||||
cache = next_cache
|
||||
|
||||
return input[0] if squeeze_out else input
|
||||
@ -729,7 +678,6 @@ class Upsample3D(nn.Module):
|
||||
inflation_mode = "tail",
|
||||
temporal_up: bool = False,
|
||||
spatial_up: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@ -760,9 +708,9 @@ class Upsample3D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state=None,
|
||||
memory_cache=None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
if hidden_states.shape[1] != self.channels:
|
||||
raise ValueError(f"SeedVR2 upsample expected {self.channels} channels, got {hidden_states.shape[1]}.")
|
||||
|
||||
hidden_states = self.upscale_conv(hidden_states)
|
||||
b, channels, f, h, w = hidden_states.shape
|
||||
@ -785,8 +733,6 @@ class Upsample3D(nn.Module):
|
||||
|
||||
|
||||
class Downsample3D(nn.Module):
|
||||
"""A 3D downsampling layer with an optional convolution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
@ -794,7 +740,6 @@ class Downsample3D(nn.Module):
|
||||
inflation_mode = "tail",
|
||||
spatial_down: bool = False,
|
||||
temporal_down: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
@ -823,20 +768,17 @@ class Downsample3D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if hasattr(self, "norm") and self.norm is not None:
|
||||
# [Overridden] change to causal norm.
|
||||
hidden_states = causal_norm_wrapper(self.norm, hidden_states)
|
||||
if hidden_states.shape[1] != self.channels:
|
||||
raise ValueError(f"SeedVR2 downsample expected {self.channels} channels, got {hidden_states.shape[1]}.")
|
||||
|
||||
if self.spatial_down:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
if hidden_states.shape[1] != self.channels:
|
||||
raise ValueError(f"SeedVR2 downsample expected {self.channels} channels after padding, got {hidden_states.shape[1]}.")
|
||||
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
@ -848,7 +790,6 @@ class ResnetBlock3D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
@ -857,7 +798,6 @@ class ResnetBlock3D(nn.Module):
|
||||
skip_time_act: bool = False,
|
||||
inflation_mode = "tail",
|
||||
time_receptive_field: _receptive_field_t = "half",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
@ -866,15 +806,14 @@ class ResnetBlock3D(nn.Module):
|
||||
self.skip_time_act = skip_time_act
|
||||
self.nonlinearity = nn.SiLU()
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = ops.Linear(temb_channels, out_channels)
|
||||
self.time_emb_proj = ops.Linear(temb_channels, self.out_channels)
|
||||
else:
|
||||
self.time_emb_proj = None
|
||||
self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.use_in_shortcut = self.in_channels != out_channels
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=self.out_channels, eps=eps, affine=True)
|
||||
self.use_in_shortcut = self.in_channels != self.out_channels
|
||||
self.conv1 = InflatedCausalConv3d(
|
||||
self.in_channels,
|
||||
self.out_channels,
|
||||
@ -886,7 +825,7 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
self.conv2 = InflatedCausalConv3d(
|
||||
self.out_channels,
|
||||
out_channels,
|
||||
self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
@ -897,7 +836,7 @@ class ResnetBlock3D(nn.Module):
|
||||
if self.use_in_shortcut:
|
||||
self.conv_shortcut = InflatedCausalConv3d(
|
||||
self.in_channels,
|
||||
out_channels,
|
||||
self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
@ -905,9 +844,7 @@ class ResnetBlock3D(nn.Module):
|
||||
inflation_mode=inflation_mode,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs
|
||||
):
|
||||
def forward(self, input_tensor, temb, memory_state = None, memory_cache = None):
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
|
||||
@ -928,7 +865,6 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
@ -944,7 +880,6 @@ class DownEncoderBlock3D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
@ -957,28 +892,23 @@ class DownEncoderBlock3D(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
temporal_modules = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
# [Override] Replace module.
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=None,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
output_scale_factor=output_scale_factor,
|
||||
inflation_mode=inflation_mode,
|
||||
time_receptive_field=time_receptive_field,
|
||||
)
|
||||
)
|
||||
temporal_modules.append(nn.Identity())
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temporal_modules = nn.ModuleList(temporal_modules)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
@ -1000,11 +930,9 @@ class DownEncoderBlock3D(nn.Module):
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, temporal in zip(self.resnets, self.temporal_modules):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
|
||||
hidden_states = temporal(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
@ -1018,7 +946,6 @@ class UpDecoderBlock3D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
@ -1032,33 +959,26 @@ class UpDecoderBlock3D(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
temporal_modules = []
|
||||
|
||||
for i in range(num_layers):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
# [Override] Replace module.
|
||||
ResnetBlock3D(
|
||||
in_channels=input_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
output_scale_factor=output_scale_factor,
|
||||
inflation_mode=inflation_mode,
|
||||
time_receptive_field=time_receptive_field,
|
||||
)
|
||||
)
|
||||
|
||||
temporal_modules.append(nn.Identity())
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.temporal_modules = nn.ModuleList(temporal_modules)
|
||||
|
||||
if add_upsample:
|
||||
# [Override] Replace module & use learnable upsample
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
Upsample3D(
|
||||
@ -1080,9 +1000,8 @@ class UpDecoderBlock3D(nn.Module):
|
||||
memory_state=None,
|
||||
memory_cache=None,
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, temporal in zip(self.resnets, self.temporal_modules):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
|
||||
hidden_states = temporal(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@ -1096,7 +1015,6 @@ class UNetMidBlock3D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default", # default, spatial
|
||||
@ -1111,16 +1029,13 @@ class UNetMidBlock3D(nn.Module):
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
# [Override] Replace module.
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
output_scale_factor=output_scale_factor,
|
||||
inflation_mode=inflation_mode,
|
||||
time_receptive_field=time_receptive_field,
|
||||
@ -1148,7 +1063,6 @@ class UNetMidBlock3D(nn.Module):
|
||||
),
|
||||
residual_connection=True,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -1161,7 +1075,6 @@ class UNetMidBlock3D(nn.Module):
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
output_scale_factor=output_scale_factor,
|
||||
inflation_mode=inflation_mode,
|
||||
time_receptive_field=time_receptive_field,
|
||||
@ -1172,7 +1085,7 @@ class UNetMidBlock3D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None):
|
||||
video_length, frame_height, frame_width = hidden_states.size()[-3:]
|
||||
video_length = hidden_states.size(2)
|
||||
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
@ -1195,7 +1108,6 @@ class Encoder3D(nn.Module):
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
mid_block_add_attention=True,
|
||||
# [Override] add temporal down num
|
||||
temporal_down_num: int = 2,
|
||||
inflation_mode = "tail",
|
||||
time_receptive_field: _receptive_field_t = "half",
|
||||
@ -1216,17 +1128,15 @@ class Encoder3D(nn.Module):
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
# [Override] to support temporal down block design
|
||||
is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1
|
||||
# Note: take the last ones
|
||||
|
||||
assert down_block_type == "DownEncoderBlock3D"
|
||||
if down_block_type != "DownEncoderBlock3D":
|
||||
raise ValueError(f"SeedVR2 encoder only supports DownEncoderBlock3D, got {down_block_type}.")
|
||||
|
||||
down_block = DownEncoderBlock3D(
|
||||
num_layers=self.layers_per_block,
|
||||
@ -1242,7 +1152,6 @@ class Encoder3D(nn.Module):
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
@ -1256,7 +1165,6 @@ class Encoder3D(nn.Module):
|
||||
time_receptive_field=time_receptive_field,
|
||||
)
|
||||
|
||||
# out
|
||||
self.conv_norm_out = ops.GroupNorm(
|
||||
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
@ -1274,17 +1182,13 @@ class Encoder3D(nn.Module):
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
sample = sample.to(next(self.parameters()).device)
|
||||
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
# post-process
|
||||
sample = causal_norm_wrapper(self.conv_norm_out, sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
@ -1303,7 +1207,6 @@ class Decoder3D(nn.Module):
|
||||
layers_per_block: int = 2,
|
||||
norm_num_groups: int = 32,
|
||||
mid_block_add_attention=True,
|
||||
# [Override] add temporal up block
|
||||
inflation_mode = "tail",
|
||||
time_receptive_field: _receptive_field_t = "half",
|
||||
temporal_up_num: int = 2,
|
||||
@ -1326,7 +1229,6 @@ class Decoder3D(nn.Module):
|
||||
|
||||
temb_channels = None
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
@ -1340,7 +1242,6 @@ class Decoder3D(nn.Module):
|
||||
time_receptive_field=time_receptive_field,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
@ -1349,7 +1250,8 @@ class Decoder3D(nn.Module):
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
is_temporal_up_block = i < self.temporal_up_num
|
||||
assert up_block_type == "UpDecoderBlock3D"
|
||||
if up_block_type != "UpDecoderBlock3D":
|
||||
raise ValueError(f"SeedVR2 decoder only supports UpDecoderBlock3D, got {up_block_type}.")
|
||||
up_block = UpDecoderBlock3D(
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=prev_output_channel,
|
||||
@ -1365,7 +1267,6 @@ class Decoder3D(nn.Module):
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = ops.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
||||
)
|
||||
@ -1375,7 +1276,6 @@ class Decoder3D(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# Note: Just copy from Decoder.
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@ -1388,15 +1288,12 @@ class Decoder3D(nn.Module):
|
||||
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
# post-process
|
||||
sample = causal_norm_wrapper(self.conv_norm_out, sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
@ -1415,8 +1312,6 @@ class VideoAutoencoderKL(nn.Module):
|
||||
inflation_mode = "pad",
|
||||
time_receptive_field: _receptive_field_t = "full",
|
||||
slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.slicing_sample_min_size = slicing_sample_min_size
|
||||
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
|
||||
@ -1425,7 +1320,6 @@ class VideoAutoencoderKL(nn.Module):
|
||||
up_block_types = ("UpDecoderBlock3D",) * 4
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
@ -1433,13 +1327,11 @@ class VideoAutoencoderKL(nn.Module):
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
# [Override] add temporal_down_num parameter
|
||||
temporal_down_num=temporal_scale_num,
|
||||
inflation_mode=inflation_mode,
|
||||
time_receptive_field=time_receptive_field,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder3D(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
@ -1447,7 +1339,6 @@ class VideoAutoencoderKL(nn.Module):
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
# [Override] add temporal_up_num parameter
|
||||
temporal_up_num=temporal_scale_num,
|
||||
inflation_mode=inflation_mode,
|
||||
time_receptive_field=time_receptive_field,
|
||||
@ -1489,11 +1380,10 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return output.to(z.device)
|
||||
|
||||
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
sp_size =1
|
||||
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
|
||||
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size:
|
||||
memory_cache = {}
|
||||
split_size = max(
|
||||
self.slicing_sample_min_size * sp_size,
|
||||
self.slicing_sample_min_size,
|
||||
getattr(self, "temporal_downsample_factor", 1),
|
||||
)
|
||||
x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2))
|
||||
@ -1518,10 +1408,9 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return self._encode(x)
|
||||
|
||||
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
sp_size = 1
|
||||
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
|
||||
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size:
|
||||
memory_cache = {}
|
||||
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
|
||||
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size, dim=2)
|
||||
decoded_slices = [
|
||||
self._decode(
|
||||
torch.cat((z[:, :, :1], z_slices[0]), dim=2),
|
||||
@ -1538,33 +1427,28 @@ class VideoAutoencoderKL(nn.Module):
|
||||
else:
|
||||
return self._decode(z)
|
||||
|
||||
def forward(
|
||||
self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
|
||||
):
|
||||
# x: [b c t h w]
|
||||
def forward(self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all"):
|
||||
def _unwrap(value):
|
||||
return value[0] if isinstance(value, tuple) else value
|
||||
|
||||
if mode == "encode":
|
||||
return _unwrap(self.encode(x))
|
||||
elif mode == "decode":
|
||||
if mode == "decode":
|
||||
return _unwrap(self.decode_(x))
|
||||
else:
|
||||
if mode == "all":
|
||||
latent = _unwrap(self.encode(x))
|
||||
return _unwrap(self.decode_(latent))
|
||||
raise ValueError(f"Unknown SeedVR2 VAE forward mode: {mode}")
|
||||
|
||||
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
spatial_downsample_factor = 8,
|
||||
temporal_downsample_factor = 4,
|
||||
**kwargs,
|
||||
):
|
||||
self.spatial_downsample_factor = spatial_downsample_factor
|
||||
self.temporal_downsample_factor = temporal_downsample_factor
|
||||
self.enable_tiling = False
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__()
|
||||
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
|
||||
|
||||
def forward(self, x: torch.FloatTensor):
|
||||
@ -1581,7 +1465,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
z = p.squeeze(2)
|
||||
return z, p
|
||||
|
||||
def encode(self, x, orig_dims=None):
|
||||
def encode(self, x):
|
||||
z, _ = self._encode_with_raw_latent(x)
|
||||
return z
|
||||
|
||||
@ -1594,26 +1478,27 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
)
|
||||
|
||||
if z.ndim == 5:
|
||||
b, c, t_latent, h, w = z.shape
|
||||
if c != 16:
|
||||
_, c, _, _, _ = z.shape
|
||||
if c != SEEDVR2_LATENT_CHANNELS:
|
||||
raise RuntimeError(
|
||||
"SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must "
|
||||
f"have 16 channels; got shape {tuple(z.shape)}."
|
||||
f"have {SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(z.shape)}."
|
||||
)
|
||||
latent = z
|
||||
elif z.ndim == 4:
|
||||
b, tc, h, w = z.shape
|
||||
if tc % 16 != 0:
|
||||
if tc % SEEDVR2_LATENT_CHANNELS != 0:
|
||||
raise RuntimeError(
|
||||
"SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must "
|
||||
"use collapsed channel layout (B, 16*T, H, W); "
|
||||
f"use collapsed channel layout (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W); "
|
||||
f"got shape {tuple(z.shape)}."
|
||||
)
|
||||
latent = z.reshape(b, 16, -1, h, w)
|
||||
latent = z.reshape(b, SEEDVR2_LATENT_CHANNELS, -1, h, w)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be "
|
||||
"4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); "
|
||||
f"4-D collapsed (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W) or "
|
||||
f"5-D (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||
f"got shape {tuple(z.shape)}."
|
||||
)
|
||||
scale = BYTEDANCE_VAE_SCALING_FACTOR
|
||||
@ -1621,10 +1506,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
latent = latent / scale + shift
|
||||
|
||||
self.device = latent.device
|
||||
self.enable_tiling = seedvr2_tiling.get("enable_tiling", False)
|
||||
enable_tiling = seedvr2_tiling.get("enable_tiling", False)
|
||||
|
||||
if self.enable_tiling:
|
||||
if enable_tiling:
|
||||
decode_seedvr2_args = dict(seedvr2_tiling)
|
||||
decode_seedvr2_args.pop("enable_tiling", None)
|
||||
tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512))
|
||||
ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64))
|
||||
decode_seedvr2_args["tile_overlap"] = (
|
||||
@ -1641,7 +1527,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
else:
|
||||
x = super().decode_(latent)
|
||||
|
||||
# ensure even dims for save video
|
||||
h, w = x.shape[-2:]
|
||||
w2 = w - (w % 2)
|
||||
h2 = h - (h % 2)
|
||||
@ -1693,7 +1578,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
if samples.ndim == 4:
|
||||
samples = samples.unsqueeze(2)
|
||||
samples = samples.contiguous()
|
||||
samples = samples * 0.9152
|
||||
samples = samples * BYTEDANCE_VAE_SCALING_FACTOR
|
||||
return samples
|
||||
|
||||
def comfy_memory_used_decode(self, shape):
|
||||
@ -1707,15 +1592,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
|
||||
if len(shape) == 5:
|
||||
candidates = []
|
||||
if shape[1] == 16:
|
||||
if shape[1] == SEEDVR2_LATENT_CHANNELS:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
if shape[-1] == 16:
|
||||
if shape[-1] == SEEDVR2_LATENT_CHANNELS:
|
||||
candidates.append((shape[1], shape[2], shape[3]))
|
||||
if len(candidates) == 0:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
pixels = max(output_pixels(*candidate) for candidate in candidates)
|
||||
elif len(shape) == 4:
|
||||
latent_t = max(1, (shape[1] + 15) // 16)
|
||||
latent_t = max(1, (shape[1] + SEEDVR2_LATENT_CHANNELS - 1) // SEEDVR2_LATENT_CHANNELS)
|
||||
pixels = output_pixels(latent_t, shape[2], shape[3])
|
||||
else:
|
||||
pixels = output_pixels(1, shape[-2], shape[-1])
|
||||
|
||||
@ -933,7 +933,8 @@ class HunyuanDiT(BaseModel):
|
||||
|
||||
class SeedVR2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
condition = kwargs.get("condition", None)
|
||||
|
||||
@ -598,43 +598,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
|
||||
seedvr2_7b_separate_key = "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)
|
||||
if seedvr2_7b_separate_key in state_dict_keys and state_dict[seedvr2_7b_separate_key].shape[1] == 3072: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
|
||||
# submodules) at EVERY block — verified by inspecting the 7B
|
||||
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
|
||||
# ``MMModule.shared_weights=False``). Native NaDiT computes
|
||||
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
|
||||
# every block non-shared we set ``mm_layers = num_layers``.
|
||||
# Without this, blocks at index >= mm_layers (default 10) try to
|
||||
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
|
||||
# silently miss-load → all-black output.
|
||||
# This checkpoint uses separate vid/txt MMModule keys in every block.
|
||||
dit_config["mm_layers"] = 36
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "normal"
|
||||
return dit_config
|
||||
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
if "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# This checkpoint layout carries shared ``all.`` MMModule keys.
|
||||
# Preserve the historical split: the initial blocks use separate
|
||||
# vid/txt modules, later blocks use shared modules.
|
||||
# This checkpoint uses shared all.* MMModule keys after the initial blocks.
|
||||
dit_config["mm_layers"] = 10
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
return dit_config
|
||||
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||
if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 2560
|
||||
@ -1150,8 +1141,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config["heatmap_head"] = True
|
||||
|
||||
return unet_config
|
||||
def normalize_seedvr2_unet_config(unet_config):
|
||||
if unet_config.get("image_model") != "seedvr2" or "num_heads" not in unet_config:
|
||||
return unet_config
|
||||
|
||||
unet_config = dict(unet_config)
|
||||
num_heads = unet_config.pop("num_heads")
|
||||
if "heads" in unet_config and unet_config["heads"] != num_heads:
|
||||
raise ValueError(
|
||||
f"SeedVR2 config has conflicting heads={unet_config['heads']} and num_heads={num_heads}."
|
||||
)
|
||||
unet_config["heads"] = num_heads
|
||||
return unet_config
|
||||
|
||||
|
||||
def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""):
|
||||
unet_config = normalize_seedvr2_unet_config(unet_config)
|
||||
for model_config in comfy.supported_models.models:
|
||||
if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix):
|
||||
return model_config(unet_config)
|
||||
|
||||
46
comfy/sd.py
46
comfy/sd.py
@ -472,8 +472,7 @@ class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
|
||||
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
@ -549,7 +548,7 @@ class VAE:
|
||||
self.latent_channels = 16
|
||||
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||
self.latent_channels = 16
|
||||
self.latent_channels = comfy.ldm.seedvr.vae.SEEDVR2_LATENT_CHANNELS
|
||||
self.latent_dim = 3
|
||||
self.disable_offload = True
|
||||
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
|
||||
@ -1074,6 +1073,20 @@ class VAE:
|
||||
out = self.first_stage_model.encode_tiled(x, **kwargs)
|
||||
return out.to(device=self.output_device, dtype=self.vae_output_dtype())
|
||||
|
||||
def _owned_tiled_args(self, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
args["overlap_t"] = overlap_t
|
||||
return args
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
@ -1153,18 +1166,7 @@ class VAE:
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if self.handles_tiling and dims in (2, 3):
|
||||
tiled_args = {}
|
||||
if tile_x is not None:
|
||||
tiled_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
tiled_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
tiled_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
tiled_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
tiled_args["overlap_t"] = overlap_t
|
||||
output = self._decode_tiled_owned(samples, **tiled_args)
|
||||
output = self._decode_tiled_owned(samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
|
||||
elif dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
@ -1269,18 +1271,7 @@ class VAE:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if self.handles_tiling:
|
||||
tiled_args = {}
|
||||
if tile_x is not None:
|
||||
tiled_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
tiled_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
tiled_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
tiled_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
tiled_args["overlap_t"] = overlap_t
|
||||
samples = self._encode_tiled_owned(pixel_samples, **tiled_args)
|
||||
samples = self._encode_tiled_owned(pixel_samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
|
||||
else:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
@ -1850,7 +1841,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
|
||||
@ -1688,6 +1688,7 @@ class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
unet_extra_config = {}
|
||||
required_keys = {
|
||||
"{}positive_conditioning",
|
||||
"{}negative_conditioning",
|
||||
|
||||
@ -19,21 +19,14 @@ from comfy.ldm.seedvr.constants import (
|
||||
)
|
||||
|
||||
from torchvision.transforms import functional as TVF
|
||||
from torchvision.transforms import Lambda
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
|
||||
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
|
||||
)
|
||||
|
||||
# Private sentinel for getattr default: distinguishes "attribute missing"
|
||||
# from "attribute present but None" so the failure message is accurate.
|
||||
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
|
||||
_ATTR_MISSING = object()
|
||||
|
||||
|
||||
def _resolve_seedvr2_diffusion_model(model):
|
||||
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
|
||||
inner = getattr(model, "model", _ATTR_MISSING)
|
||||
if inner is _ATTR_MISSING:
|
||||
raise RuntimeError(
|
||||
@ -59,15 +52,7 @@ def _resolve_seedvr2_diffusion_model(model):
|
||||
return diffusion_model
|
||||
|
||||
|
||||
def get_conditions(latent, latent_blur):
|
||||
t, h, w, c = latent.shape
|
||||
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
|
||||
cond[:, ..., :-1] = latent_blur[:]
|
||||
cond[:, ..., -1:] = 1.0
|
||||
return cond
|
||||
|
||||
def div_pad(image, factor):
|
||||
|
||||
height_factor, width_factor = factor
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
@ -77,31 +62,25 @@ def div_pad(image, factor):
|
||||
if pad_height == 0 and pad_width == 0:
|
||||
return image
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
padding = (0, pad_width, 0, pad_height)
|
||||
image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
|
||||
|
||||
return image
|
||||
padding = (0, pad_width, 0, pad_height)
|
||||
return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
|
||||
|
||||
def cut_videos(videos):
|
||||
t = videos.size(1)
|
||||
if t < 1:
|
||||
raise ValueError("SeedVR2Preprocess expected at least one frame.")
|
||||
if t == 1:
|
||||
return videos
|
||||
if t <= 4 :
|
||||
padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
|
||||
padding = torch.cat(padding, dim=1)
|
||||
videos = torch.cat([videos, padding], dim=1)
|
||||
return videos
|
||||
if (t - 1) % (4) == 0:
|
||||
return videos
|
||||
else:
|
||||
padding = [videos[:, -1].unsqueeze(1)] * (
|
||||
4 - ((t - 1) % (4))
|
||||
)
|
||||
padding = torch.cat(padding, dim=1)
|
||||
videos = torch.cat([videos, padding], dim=1)
|
||||
assert (videos.size(1) - 1) % (4) == 0
|
||||
if t <= 4:
|
||||
padding = videos[:, -1:].repeat(1, 4 - t + 1, 1, 1, 1)
|
||||
return torch.cat([videos, padding], dim=1)
|
||||
if (t - 1) % 4 == 0:
|
||||
return videos
|
||||
padding = videos[:, -1:].repeat(1, 4 - ((t - 1) % 4), 1, 1, 1)
|
||||
videos = torch.cat([videos, padding], dim=1)
|
||||
if (videos.size(1) - 1) % 4 != 0:
|
||||
raise ValueError(f"SeedVR2Preprocess failed to pad video length to 4n+1; got {videos.size(1)} frames.")
|
||||
return videos
|
||||
|
||||
def _seedvr2_input_shorter_edge(images, node_name):
|
||||
if images.dim() == 4:
|
||||
@ -136,8 +115,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
|
||||
b, t, c, h, w = images.shape
|
||||
images = images.reshape(b * t, c, h, w)
|
||||
|
||||
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
||||
images = clip(images)
|
||||
images = torch.clamp(images, 0.0, 1.0)
|
||||
images = div_pad(images, (16, 16))
|
||||
_, _, new_h, new_w = images.shape
|
||||
|
||||
@ -295,7 +273,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
|
||||
chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
|
||||
while True:
|
||||
next_chunk_size = None
|
||||
try:
|
||||
return cls._run_color_transfer_chunks(
|
||||
decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
|
||||
@ -307,9 +284,7 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
"SeedVR2PostProcessing: color correction OOM at one frame; "
|
||||
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
|
||||
) from e
|
||||
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
|
||||
|
||||
chunk_size = next_chunk_size
|
||||
chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
|
||||
|
||||
@classmethod
|
||||
def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
|
||||
@ -392,10 +367,8 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
io.Latent.Input("vae_conditioning", display_name="latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."),
|
||||
io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
|
||||
io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
|
||||
io.Latent.Output(display_name="latent", tooltip="The latent to denoise."),
|
||||
],
|
||||
)
|
||||
|
||||
@ -408,29 +381,30 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
|
||||
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
|
||||
if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
|
||||
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS:
|
||||
raise ValueError(
|
||||
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
|
||||
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||
f"got channel-last shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
raise ValueError(
|
||||
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
|
||||
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||
f"got channel-last shape {tuple(vae_conditioning.shape)}."
|
||||
"SeedVR2Conditioning expects SeedVR2 VAE latents with "
|
||||
f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
|
||||
model_patcher = model
|
||||
model = _resolve_seedvr2_diffusion_model(model_patcher)
|
||||
model = _resolve_seedvr2_diffusion_model(model)
|
||||
pos_cond = model.positive_conditioning
|
||||
neg_cond = model.negative_conditioning
|
||||
|
||||
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning])
|
||||
mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,))
|
||||
condition = torch.cat((vae_conditioning, mask), dim=-1)
|
||||
condition = condition.movedim(-1, 1)
|
||||
latent = vae_conditioning.movedim(-1, 1)
|
||||
|
||||
latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4])
|
||||
condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4])
|
||||
|
||||
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
|
||||
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
|
||||
|
||||
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
|
||||
return io.NodeOutput(positive, negative)
|
||||
|
||||
class SeedVRExtension(ComfyExtension):
|
||||
@override
|
||||
|
||||
@ -1,20 +1,15 @@
|
||||
"""Consolidated SeedVR2 conditioning and refactor regression tests.
|
||||
|
||||
Merges the prior test_seedvr2_refactor_nodes.py and
|
||||
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
|
||||
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
|
||||
use _import_nodes_seedvr_isolated() for sys.modules isolation when
|
||||
mocking comfy.model_management.
|
||||
"""
|
||||
"""SeedVR2 conditioning node regression tests."""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
from comfy.ldm.seedvr.constants import SEEDVR2_LATENT_CHANNELS
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
@ -79,21 +74,18 @@ def _import_nodes_seedvr_isolated():
|
||||
|
||||
|
||||
class _Rope(nn.Module):
|
||||
"""Minimal RoPE stub exposing a `freqs` parameter."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.freqs = nn.Parameter(torch.zeros(4))
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Minimal transformer block stub holding a `_Rope`."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rope = _Rope()
|
||||
|
||||
|
||||
class _DiffusionModel(nn.Module):
|
||||
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
|
||||
def __init__(self, n_blocks=3, conditioning_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
|
||||
@ -102,18 +94,16 @@ class _DiffusionModel(nn.Module):
|
||||
|
||||
|
||||
class _ModelInner:
|
||||
"""Inner model wrapper exposing `.diffusion_model`."""
|
||||
def __init__(self, diffusion_model):
|
||||
self.diffusion_model = diffusion_model
|
||||
|
||||
|
||||
class _ModelPatcher:
|
||||
"""ModelPatcher stub exposing `.model._ModelInner`."""
|
||||
def __init__(self, diffusion_model):
|
||||
self.model = _ModelInner(diffusion_model)
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
||||
def test_seedvr2_conditioning_schema_exposes_conditioning_outputs():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
|
||||
@ -123,37 +113,50 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
||||
]
|
||||
assert schema.inputs[1].display_name == "latent"
|
||||
assert [output.display_name for output in schema.outputs] == [
|
||||
"model",
|
||||
"positive",
|
||||
"negative",
|
||||
"latent",
|
||||
]
|
||||
finally:
|
||||
restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||
def test_seedvr2_conditioning_rejects_wrong_latent_channels():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
patcher = _ModelPatcher(_DiffusionModel())
|
||||
vae_conditioning = {"samples": torch.zeros(1, 8, 2, 2, 2)}
|
||||
|
||||
with pytest.raises(ValueError, match=f"{SEEDVR2_LATENT_CHANNELS} channels"):
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning)
|
||||
finally:
|
||||
restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_returns_conditioning_deterministically():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
diffusion_model = _DiffusionModel()
|
||||
patcher = _ModelPatcher(diffusion_model)
|
||||
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
|
||||
samples = torch.arange(
|
||||
1,
|
||||
1 + SEEDVR2_LATENT_CHANNELS * 3 * 2 * 2,
|
||||
dtype=torch.float32,
|
||||
).reshape(1, SEEDVR2_LATENT_CHANNELS, 3, 2, 2)
|
||||
vae_conditioning = {"samples": samples}
|
||||
|
||||
_, first_positive, first_negative, first_latent = (
|
||||
first_positive, first_negative = (
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||
patcher,
|
||||
vae_conditioning,
|
||||
)
|
||||
)
|
||||
_, second_positive, second_negative, second_latent = (
|
||||
second_positive, second_negative = (
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||
patcher,
|
||||
vae_conditioning,
|
||||
)
|
||||
)
|
||||
|
||||
expected_latent = samples.reshape(1, 6, 2, 2)
|
||||
channel_last = samples.movedim(1, -1).contiguous()
|
||||
expected_condition = torch.cat(
|
||||
[
|
||||
@ -161,10 +164,8 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||
torch.ones((*channel_last.shape[:-1], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
).movedim(-1, 1).reshape(1, 9, 2, 2)
|
||||
).movedim(-1, 1)
|
||||
|
||||
assert torch.equal(first_latent["samples"], expected_latent)
|
||||
assert torch.equal(second_latent["samples"], expected_latent)
|
||||
assert torch.equal(
|
||||
first_positive[0][1]["condition"],
|
||||
expected_condition,
|
||||
|
||||
@ -201,6 +201,17 @@ class TestModelDetection:
|
||||
del sd["positive_conditioning"]
|
||||
assert model_config_from_unet_config(unet_config, sd) is None
|
||||
|
||||
def test_seedvr2_model_match_normalizes_num_heads(self):
|
||||
sd = _make_seedvr2_7b_shared_mm_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
unet_config["num_heads"] = unet_config.pop("heads")
|
||||
|
||||
model_config = model_config_from_unet_config(unet_config, sd)
|
||||
|
||||
assert type(model_config).__name__ == "SeedVR2"
|
||||
assert model_config.unet_config["heads"] == 24
|
||||
assert "num_heads" not in model_config.unet_config
|
||||
|
||||
def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self):
|
||||
sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd())
|
||||
|
||||
|
||||
@ -1,22 +1,6 @@
|
||||
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
|
||||
honor the actual tensor/tuple return contract of ``encode()`` and
|
||||
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
|
||||
or ``.sample`` attributes on those returns.
|
||||
|
||||
The pre-fix body raised ``AttributeError: 'Tensor' object has no
|
||||
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
|
||||
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
|
||||
for ``mode == "decode"`` (the class only defines ``decode_`` with a
|
||||
trailing underscore). The post-fix body unwraps the optional one-element
|
||||
tuple shape that ``return_dict=False`` produces and returns the tensor
|
||||
directly.
|
||||
|
||||
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
|
||||
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
|
||||
overrides ``encode``/``decode_`` with known tensors so the contract can
|
||||
be probed without loading any real VAE weights.
|
||||
"""
|
||||
"""Regression tests for the SeedVR2 VAE forward return contract."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -25,13 +9,13 @@ from comfy.cli_args import args as cli_args
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
|
||||
from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402
|
||||
|
||||
|
||||
_LATENT_SHAPE = (1, 16, 2, 2, 2)
|
||||
_LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2)
|
||||
_DECODED_SHAPE = (1, 3, 5, 16, 16)
|
||||
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
|
||||
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
|
||||
_INPUT_DECODE_SHAPE = _LATENT_SHAPE
|
||||
|
||||
|
||||
class _StubVAE(VideoAutoencoderKL):
|
||||
@ -64,8 +48,6 @@ def test_forward_decode_returns_tensor():
|
||||
|
||||
|
||||
class _TupleReturningStubVAE(VideoAutoencoderKL):
|
||||
"""Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``."""
|
||||
|
||||
def __init__(self):
|
||||
nn.Module.__init__(self)
|
||||
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
|
||||
@ -84,3 +66,9 @@ def test_forward_all_unwraps_one_tuple_at_each_step():
|
||||
result = vae.forward(x, mode="all")
|
||||
assert type(result) is torch.Tensor
|
||||
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||
|
||||
|
||||
def test_forward_rejects_unknown_mode():
|
||||
vae = _StubVAE()
|
||||
with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"):
|
||||
vae.forward(torch.zeros(*_INPUT_ENCODE_SHAPE), mode="bogus")
|
||||
|
||||
@ -41,8 +41,9 @@ def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
|
||||
|
||||
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
|
||||
wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper)
|
||||
estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160))
|
||||
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
|
||||
latent_channels = seedvr_vae.SEEDVR2_LATENT_CHANNELS
|
||||
estimate = wrapper.comfy_memory_used_decode((1, latent_channels, 26, 120, 160))
|
||||
old_estimate = latent_channels * 120 * 160 * (4 * 8 * 8) * 2
|
||||
|
||||
assert estimate == 101 * 960 * 1280 * 160
|
||||
assert estimate > 15 * 1024 ** 3
|
||||
|
||||
@ -1,16 +1,4 @@
|
||||
"""Consolidated SeedVR2 internals regression tests.
|
||||
|
||||
Sources (all merged verbatim, helper names disambiguated where colliding):
|
||||
|
||||
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
||||
memory_occupy against get_norm_limit(), not float('inf').
|
||||
* SeedVR2 variable-length attention split-loop contract.
|
||||
|
||||
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
||||
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
||||
which probes torch.cuda.current_device() at import time unless args.cpu is
|
||||
set first.
|
||||
"""
|
||||
"""SeedVR2 internals regression tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -35,10 +23,6 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402
|
||||
from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NUM_CHANNELS = 8
|
||||
_NUM_GROUPS = 4
|
||||
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
|
||||
@ -89,10 +73,6 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
||||
set_norm_limit(None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SeedVR2 var_attention split-loop tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
||||
dim = 8
|
||||
heads = 2
|
||||
@ -140,18 +120,8 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
|
||||
assert call["heads"] == heads
|
||||
assert call["skip_reshape"] is True
|
||||
assert call["skip_output_reshape"] is True
|
||||
torch.testing.assert_close(
|
||||
call["cu_seqlens_q"],
|
||||
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
call["cu_seqlens_k"],
|
||||
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
assert call["cu_seqlens_q"] == [0, 7, 14]
|
||||
assert call["cu_seqlens_k"] == [0, 7, 14]
|
||||
|
||||
|
||||
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
|
||||
@ -160,7 +130,7 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc
|
||||
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
|
||||
k = q + 100
|
||||
v = q + 200
|
||||
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||
cu = [0, 2, 5]
|
||||
calls = []
|
||||
|
||||
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
|
||||
@ -197,20 +167,3 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc
|
||||
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
|
||||
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_var_attention_optimized_split_rejects_bad_offsets():
|
||||
q = torch.randn(5, 2, 3)
|
||||
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
|
||||
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||
|
||||
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
|
||||
var_attention_optimized_split(
|
||||
q,
|
||||
q,
|
||||
q,
|
||||
2,
|
||||
cu_bad,
|
||||
cu_ok,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=True,
|
||||
)
|
||||
|
||||
@ -1,17 +1,10 @@
|
||||
"""Consolidated SeedVR2 model/graph/forward regression tests.
|
||||
|
||||
Merged from:
|
||||
- seedvr_model_test.py
|
||||
- test_seedvr_7b_final_block_text_path.py
|
||||
- test_seedvr_forward_no_device_cast.py
|
||||
- test_seedvr_latent_format.py
|
||||
- test_seedvr2_vae_graph_boundaries.py
|
||||
"""
|
||||
"""SeedVR2 model, latent-format, and VAE graph regression tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@ -22,7 +15,6 @@ if not torch.cuda.is_available():
|
||||
|
||||
import comfy # noqa: E402
|
||||
import comfy.latent_formats # noqa: E402
|
||||
import comfy.ldm.seedvr.model # noqa: E402
|
||||
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.model_management # noqa: E402
|
||||
@ -33,9 +25,7 @@ import nodes as nodes_mod # noqa: E402
|
||||
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from seedvr_model_test.py
|
||||
# ---------------------------------------------------------------------------
|
||||
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
|
||||
|
||||
|
||||
def _make_standin(positive_conditioning):
|
||||
@ -51,11 +41,6 @@ def _make_standin(positive_conditioning):
|
||||
return _StandIn()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from test_seedvr_7b_final_block_text_path.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StubModule(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
@ -88,11 +73,6 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis
|
||||
return flags
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from test_seedvr_latent_format.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Model:
|
||||
def __init__(self, latent_format):
|
||||
self._latent_format = latent_format
|
||||
@ -102,11 +82,6 @@ class _Model:
|
||||
return self._latent_format
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from test_seedvr2_vae_graph_boundaries.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Patcher:
|
||||
def get_free_memory(self, device):
|
||||
return 1024 * 1024 * 1024
|
||||
@ -136,14 +111,14 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
|
||||
if z.ndim == 4:
|
||||
b, tc, h, w = z.shape
|
||||
t = tc // 16
|
||||
t = tc // _LATENT_CHANNELS
|
||||
else:
|
||||
b, _, t, h, w = z.shape
|
||||
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
|
||||
raw_latent = torch.full((1, 16, 1, 4, 5), 2.0)
|
||||
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0)
|
||||
seen_shapes = []
|
||||
|
||||
def base_encode(self, x):
|
||||
@ -159,12 +134,12 @@ def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
|
||||
latent = vae.encode(torch.zeros(1, 3, 32, 40))
|
||||
|
||||
assert type(latent) is torch.Tensor
|
||||
assert tuple(latent.shape) == (1, 16, 4, 5)
|
||||
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
|
||||
assert seen_shapes == [(1, 3, 1, 32, 40)]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
|
||||
raw_latent = torch.full((1, 16, 1, 4, 5), 3.0)
|
||||
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0)
|
||||
|
||||
def base_encode(self, x):
|
||||
return raw_latent.to(device=x.device, dtype=x.dtype)
|
||||
@ -177,8 +152,8 @@ def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
|
||||
|
||||
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
|
||||
|
||||
assert tuple(latent.shape) == (1, 16, 4, 5)
|
||||
assert tuple(raw.shape) == (1, 16, 1, 4, 5)
|
||||
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
|
||||
assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5)
|
||||
assert torch.equal(raw, raw_latent)
|
||||
|
||||
|
||||
@ -188,7 +163,7 @@ def _make_vae(wrapper):
|
||||
vae.device = torch.device("cpu")
|
||||
vae.output_device = torch.device("cpu")
|
||||
vae.vae_dtype = torch.float32
|
||||
vae.latent_channels = 16
|
||||
vae.latent_channels = _LATENT_CHANNELS
|
||||
vae.latent_dim = 3
|
||||
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
|
||||
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
@ -212,13 +187,7 @@ def _make_vae(wrapper):
|
||||
return vae
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from seedvr_model_test.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_missing_context_falls_back_to_positive_buffer():
|
||||
"""``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion."""
|
||||
pos_buffer = torch.full((58, 5120), 7.0)
|
||||
standin = _make_standin(pos_buffer)
|
||||
txt, txt_shape = standin._resolve_text_conditioning(None)
|
||||
@ -231,11 +200,6 @@ def test_missing_context_falls_back_to_positive_buffer():
|
||||
assert txt_shape[0, 0].item() == 58
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from test_seedvr_7b_final_block_text_path.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
|
||||
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
|
||||
False,
|
||||
@ -268,43 +232,49 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
||||
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from test_seedvr_latent_format.py
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_seedvr2_forward_requires_conditioning_latents():
|
||||
model = NaDiT.__new__(NaDiT)
|
||||
x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5)
|
||||
|
||||
with pytest.raises(ValueError, match="requires conditioning latents"):
|
||||
NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None)
|
||||
|
||||
|
||||
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
|
||||
def test_seedvr2_latent_format_uses_native_video_latent_shape():
|
||||
latent_format = comfy.latent_formats.SeedVR2()
|
||||
latent_image = torch.zeros(1, 1, 4, 5)
|
||||
|
||||
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
|
||||
|
||||
assert latent_format.latent_channels == 16
|
||||
assert latent_format.latent_dimensions == 2
|
||||
assert fixed.shape == (1, 16, 4, 5)
|
||||
assert latent_format.latent_channels == _LATENT_CHANNELS
|
||||
assert latent_format.latent_dimensions == 3
|
||||
assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from test_seedvr2_vae_graph_boundaries.py
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_seedvr2_model_requires_native_5d_latent():
|
||||
latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)
|
||||
assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent
|
||||
|
||||
with pytest.raises(ValueError, match="5-D native latent"):
|
||||
NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent")
|
||||
|
||||
|
||||
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
|
||||
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||
|
||||
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
|
||||
encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0)
|
||||
vae = _make_vae(_EncodeWrapper(encoded))
|
||||
pixels = torch.zeros(1, 5, 32, 40, 3)
|
||||
|
||||
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
|
||||
node_latent = node_output["samples"]
|
||||
assert set(node_output) == {"samples"}
|
||||
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
|
||||
assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
|
||||
assert node_latent.dtype == torch.float32
|
||||
assert node_latent.stride()[-1] == 1
|
||||
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
|
||||
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
|
||||
|
||||
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
|
||||
tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0)
|
||||
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
|
||||
tiled_output = nodes_mod.VAEEncodeTiled().encode(
|
||||
vae,
|
||||
@ -316,9 +286,9 @@ def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeyp
|
||||
)[0]
|
||||
tiled_latent = tiled_output["samples"]
|
||||
assert set(tiled_output) == {"samples"}
|
||||
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
|
||||
assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
|
||||
assert tiled_latent.dtype == torch.float32
|
||||
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
||||
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
|
||||
|
||||
|
||||
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
|
||||
@ -327,7 +297,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
|
||||
|
||||
nodes_mod.VAEDecodeTiled().decode(
|
||||
vae,
|
||||
{"samples": torch.zeros(1, 16, 2, 4, 5)},
|
||||
{"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)},
|
||||
tile_size=512,
|
||||
overlap=64,
|
||||
temporal_size=16,
|
||||
@ -339,7 +309,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
|
||||
# knobs are no-ops at the wrapper.
|
||||
assert vae.first_stage_model.calls == [
|
||||
{
|
||||
"shape": (1, 16, 2, 4, 5),
|
||||
"shape": (1, _LATENT_CHANNELS, 2, 4, 5),
|
||||
"seedvr2_tiling": {
|
||||
"enable_tiling": True,
|
||||
"tile_size": (512, 512),
|
||||
|
||||
@ -13,6 +13,9 @@ import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
_LATENT_CHANNELS = vae_mod.SEEDVR2_LATENT_CHANNELS
|
||||
|
||||
|
||||
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
|
||||
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||
vae_mod.VideoAutoencoderKLWrapper
|
||||
@ -40,7 +43,7 @@ def _decode_with_patches(wrapper, z):
|
||||
def test_decode_b2_t3_multi_frame_batch_unchanged():
|
||||
wrapper = _make_wrapper()
|
||||
|
||||
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
|
||||
out = _decode_with_patches(wrapper, torch.zeros(2, _LATENT_CHANNELS * 3, 2, 2))
|
||||
|
||||
assert tuple(out.shape) == (2, 3, 3, 16, 16)
|
||||
|
||||
@ -62,17 +65,17 @@ def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preproc
|
||||
wrapper = _Wrapper()
|
||||
|
||||
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
|
||||
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
|
||||
out = wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5))
|
||||
|
||||
assert tuple(out.shape) == (1, 3, 2, 32, 40)
|
||||
assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
||||
assert wrapper.calls == [(1, _LATENT_CHANNELS, 2, 4, 5)]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
|
||||
wrapper = _Wrapper()
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
|
||||
wrapper.decode(torch.zeros(1, 16, 4))
|
||||
wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 4))
|
||||
|
||||
|
||||
def _t_padded(t_in: int) -> int:
|
||||
|
||||
@ -16,9 +16,7 @@ import comfy.sd as sd_mod # noqa: E402
|
||||
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
|
||||
# ---------------------------------------------------------------------------
|
||||
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
|
||||
|
||||
|
||||
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
@ -44,7 +42,7 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
|
||||
|
||||
vae = StubVAEModel()
|
||||
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
|
||||
z = torch.zeros((1, _LATENT_CHANNELS, 5, 8, 8), dtype=torch.float32)
|
||||
|
||||
tiled_vae(
|
||||
z,
|
||||
@ -61,11 +59,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
assert vae.slicing_latent_min_size == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
|
||||
class RaisingVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -110,7 +103,7 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing():
|
||||
def encode(self, t_chunk):
|
||||
self.calls.append(tuple(t_chunk.shape))
|
||||
b, _, _, h, w = t_chunk.shape
|
||||
return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype)
|
||||
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=t_chunk.dtype)
|
||||
|
||||
vae = TensorEncodeVAEModel()
|
||||
x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32)
|
||||
@ -126,12 +119,34 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing():
|
||||
)
|
||||
|
||||
assert vae.calls == [(2, 3, 1, 64, 64)]
|
||||
assert tuple(out.shape) == (2, 16, 1, 8, 8)
|
||||
assert tuple(out.shape) == (2, _LATENT_CHANNELS, 1, 8, 8)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_temporal_slicing.py
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_tiled_vae_preserves_input_dtype_on_single_tile():
|
||||
class FloatOutputVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.slicing_sample_min_size = 4
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.device = torch.device("cpu")
|
||||
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||
|
||||
def encode(self, t_chunk):
|
||||
b, _, _, h, w = t_chunk.shape
|
||||
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=torch.float32)
|
||||
|
||||
out = tiled_vae(
|
||||
torch.zeros((1, 3, 1, 64, 64), dtype=torch.float16),
|
||||
FloatOutputVAEModel(),
|
||||
tile_size=(64, 64),
|
||||
tile_overlap=(0, 0),
|
||||
temporal_size=0,
|
||||
temporal_overlap=0,
|
||||
encode=True,
|
||||
)
|
||||
|
||||
assert out.dtype == torch.float16
|
||||
|
||||
|
||||
class _SlicingDecodeVAE(nn.Module):
|
||||
@ -164,7 +179,10 @@ class _SlicingDecodeVAE(nn.Module):
|
||||
|
||||
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
|
||||
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
|
||||
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
|
||||
z = torch.arange(
|
||||
_LATENT_CHANNELS * 5 * 8 * 8,
|
||||
dtype=torch.float32,
|
||||
).reshape(1, _LATENT_CHANNELS, 5, 8, 8)
|
||||
|
||||
tiled_vae(
|
||||
z,
|
||||
@ -199,16 +217,11 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
|
||||
return torch.zeros(1, 3, 1, 16, 16)
|
||||
|
||||
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
|
||||
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
||||
wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
||||
|
||||
assert captured["temporal_overlap"] == 7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _force_oom(*a, **k):
|
||||
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||
|
||||
@ -256,10 +269,10 @@ def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
|
||||
def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
|
||||
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||
seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
|
||||
vae = _make_vae(wrapper, latent_channels=_LATENT_CHANNELS, latent_dim=3)
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
|
||||
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
|
||||
_dispatch(vae, torch.zeros(1, _LATENT_CHANNELS * 3, 8, 8), seedvr2_call, generic_call, True)
|
||||
assert seedvr2_call.call_count == 1
|
||||
assert generic_call.call_count == 0
|
||||
|
||||
@ -275,11 +288,6 @@ def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
|
||||
assert seedvr2_call.call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _populate_common_vae_attrs_fallback(vae):
|
||||
vae.patcher = MagicMock()
|
||||
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
|
||||
@ -291,7 +299,7 @@ def _populate_common_vae_attrs_fallback(vae):
|
||||
vae.upscale_ratio = 8
|
||||
vae.upscale_index_formula = None
|
||||
vae.output_channels = 3
|
||||
vae.latent_channels = 16
|
||||
vae.latent_channels = _LATENT_CHANNELS
|
||||
vae.latent_dim = 3
|
||||
vae.downscale_ratio = 8
|
||||
vae.downscale_index_formula = None
|
||||
@ -334,8 +342,8 @@ def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom():
|
||||
vae = _make_seedvr2_vae_fallback()
|
||||
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
|
||||
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
|
||||
|
||||
with patch.object(sd_mod.model_management, "raise_non_oom",
|
||||
lambda e: None), \
|
||||
@ -363,7 +371,7 @@ def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
|
||||
vae = _make_non_seedvr2_vae_fallback()
|
||||
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
|
||||
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
|
||||
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
|
||||
|
||||
with patch.object(sd_mod.model_management, "load_models_gpu",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user