From cd18c4460a2a38d1cb8d9dd4f468513a38f7d3f8 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:39:39 -0500 Subject: [PATCH 01/12] Add SeedVR2 model support --- comfy/latent_formats.py | 3 + comfy/ldm/modules/diffusionmodules/model.py | 6 +- comfy/ldm/seedvr/attention.py | 77 + comfy/ldm/seedvr/constants.py | 72 + comfy/ldm/seedvr/model.py | 1487 +++++++++++++++++++ comfy/model_base.py | 11 + comfy/model_detection.py | 47 + comfy/supported_models.py | 30 + comfy/supported_models_base.py | 2 +- 9 files changed, 1732 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/seedvr/attention.py create mode 100644 comfy/ldm/seedvr/constants.py create mode 100644 comfy/ldm/seedvr/model.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index bbdfd4bc2..fc5b13c21 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -779,6 +779,9 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class SeedVR2(LatentFormat): + latent_channels = 16 + class ACEAudio15(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index fcbaa074f..e752d0ecb 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -22,7 +22,7 @@ def torch_cat_if_needed(xl, dim): else: return None -def get_timestep_embedding(timesteps, embedding_dim): +def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -33,11 +33,13 @@ def get_timestep_embedding(timesteps, embedding_dim): assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - downscale_freq_shift) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/seedvr/attention.py b/comfy/ldm/seedvr/attention.py new file mode 100644 index 000000000..29ffded38 --- /dev/null +++ b/comfy/ldm/seedvr/attention.py @@ -0,0 +1,77 @@ +import torch + +from comfy.ldm.modules import attention as _attention + + +def _var_attention_qkv(q, k, v, heads, skip_reshape): + if skip_reshape: + return q, k, v, q.shape[-1] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + return ( + q.view(total_tokens, heads, head_dim), + k.view(k.shape[0], heads, head_dim), + v.view(v.shape[0], heads, head_dim), + head_dim, + ) + + +def _var_attention_output(out, heads, head_dim, skip_output_reshape): + if skip_output_reshape: + return out + return out.reshape(-1, heads * head_dim) + + +def _validate_split_cu_seqlens(name, cu_seqlens, token_count): + if cu_seqlens.dtype not in (torch.int32, torch.int64): + raise ValueError(f"{name} must use an integer dtype") + if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: + raise ValueError(f"{name} must be a 1D tensor with at least two offsets") + if cu_seqlens[0].item() != 0: + raise ValueError(f"{name} must start at 0") + if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): + raise ValueError(f"{name} must be strictly increasing") + if cu_seqlens[-1].item() != token_count: + raise ValueError(f"{name} does not match token count") + + +def _split_indices(cu_seqlens): + return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) + + +def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + + _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) + _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) + if cu_seqlens_k[-1].item() != v.shape[0]: + raise ValueError("cu_seqlens_k does not match v token count") + + q_split_indices = _split_indices(cu_seqlens_q) + k_split_indices = _split_indices(cu_seqlens_k) + q_splits = torch.tensor_split(q, q_split_indices, dim=0) + k_splits = torch.tensor_split(k, k_split_indices, dim=0) + v_splits = torch.tensor_split(v, k_split_indices, dim=0) + if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): + raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") + + out = [] + for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): + q_i = q_i.permute(1, 0, 2).unsqueeze(0) + k_i = k_i.permute(1, 0, 2).unsqueeze(0) + v_i = v_i.permute(1, 0, 2).unsqueeze(0) + out_dtype = q_i.dtype + if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): + q_i = q_i.to(torch.bfloat16) + k_i = k_i.to(torch.bfloat16) + v_i = v_i.to(torch.bfloat16) + out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) + if out_i.dtype != out_dtype: + out_i = out_i.to(out_dtype) + out.append(out_i.squeeze(0).permute(1, 0, 2)) + + out = torch.cat(out, dim=0) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +optimized_var_attention = var_attention_optimized_split diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py new file mode 100644 index 000000000..71b71d4ad --- /dev/null +++ b/comfy/ldm/seedvr/constants.py @@ -0,0 +1,72 @@ +"""Named constants for the SeedVR2 integration, grouped by provenance. + +Provenance prefixes: +- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. +- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites + the upstream config/source path it was lifted from. +- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / + ISO / CIE values; cite the standard. +""" + +# -------------------------------------------------------------------------------------- +# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) +# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) +# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 +# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). +# -------------------------------------------------------------------------------------- +SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) +SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB + +# -------------------------------------------------------------------------------------- +# B. Fork heuristics (SEEDVR2 - this integration) +# -------------------------------------------------------------------------------------- +SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. + # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. +SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). +SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. +SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. +SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). +SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). + +# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) +SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. +SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. +SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. +SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. + +# -------------------------------------------------------------------------------------- +# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) +# -------------------------------------------------------------------------------------- +BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. +BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. +BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). +BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). +BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). +BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). +BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. +BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). +BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. +BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). +BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). +BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). +BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). +BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). + +# -------------------------------------------------------------------------------------- +# D. Published standards (cite the literature) +# -------------------------------------------------------------------------------------- +ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. + +# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). +CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). +CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). +D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). +D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. +WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). + +# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and +# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the +# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py new file mode 100644 index 000000000..e7d3deb35 --- /dev/null +++ b/comfy/ldm/seedvr/model.py @@ -0,0 +1,1487 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union, List, Dict, Any, Callable +import einops +from einops import rearrange +import torch.nn.functional as F +from math import ceil, pi +import torch +from itertools import chain +from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding +from comfy.ldm.seedvr.attention import optimized_var_attention +from torch.nn.modules.utils import _triple +from torch import nn +import math +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_720P_REF_AREA, + BYTEDANCE_MAX_TEMPORAL_WINDOW, + BYTEDANCE_ROPE_MAX_FREQ, + BYTEDANCE_SINUSOIDAL_DIM, + ROPE_THETA, + SEEDVR2_7B_MLP_CHUNK, + SEEDVR2_7B_VID_DIM, + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, +) +import comfy.model_management +import numbers + +def _torch_float8_types(): + return tuple( + getattr(torch, name) + for name in ( + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) + ) + +class CustomRMSNorm(nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): + super(CustomRMSNorm, self).__init__() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) + else: + self.register_parameter('weight', None) + + def forward(self, input): + + dims = tuple(range(-len(self.normalized_shape), 0)) + + # Norm statistics in fp32 (fp16 variance underflows); activations return + # in the input dtype so downstream linears run at the model compute dtype. + normalized = input.float() + variance = normalized.pow(2).mean(dim=dims, keepdim=True) + rms = torch.sqrt(variance + self.eps) + + normalized = normalized / rms + + if self.elementwise_affine: + return (normalized * self.weight.to(torch.float32)).to(input.dtype) + return normalized.to(input.dtype) + +class Cache: + def __init__(self, disable=False, prefix="", cache=None): + self.cache = cache if cache is not None else {} + self.disable = disable + self.prefix = prefix + + def __call__(self, key: str, fn: Callable): + if self.disable: + return fn() + + key = self.prefix + key + try: + result = self.cache[key] + except KeyError: + result = fn() + self.cache[key] = result + return result + + def namespace(self, namespace: str): + return Cache( + disable=self.disable, + prefix=self.prefix + namespace + ".", + cache=self.cache, + ) + +def repeat_concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: List, # (n) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + txt = [[x] * n for x, n in zip(txt, txt_repeat)] + txt = list(chain(*txt)) + return torch.cat(list(chain(*zip(vid, txt)))) + +def repeat_concat_idx( + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: torch.LongTensor, # (n) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + txt_repeat_list = txt_repeat.tolist() + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + src_idx = torch.argsort(tgt_idx) + txt_idx_len = len(tgt_idx) - len(vid_idx) + repeat_txt_len = (txt_len * txt_repeat).tolist() + + def unconcat_coalesce(all): + vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) + txt_out_coalesced = [] + for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): + txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) + txt_out_coalesced.append(txt) + return vid_out, torch.cat(txt_out_coalesced) + + return ( + lambda vid, txt: torch.cat([vid, txt])[tgt_idx], + lambda all: unconcat_coalesce(all), + ) + + +@dataclass +class MMArg: + vid: Any + txt: Any + +def safe_pad_operation(x, padding, mode='constant', value=0.0): + try: + return F.pad(x, padding, mode=mode, value=value) + except RuntimeError as e: + if "not implemented for" in str(e) and x.dtype in (torch.float16, torch.bfloat16): + return F.pad(x.float(), padding, mode=mode, value=value).to(x.dtype) + raise + + +def get_args(key: str, args: List[Any]) -> List[Any]: + return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] + + +def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} + + +def get_window_op(name: str): + if name == "720pwin_by_size_bysize": + return make_720Pwindows_bysize + if name == "720pswin_by_size_bysize": + return make_shifted_720Pwindows_bysize + raise ValueError(f"Unknown windowing method: {name}") + + +# -------------------------------- Windowing -------------------------------- # +def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. + nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. + return [ + ( + slice(it * wt, min((it + 1) * wt, t)), + slice(ih * wh, min((ih + 1) * wh, h)), + slice(iw * ww, min((iw + 1) * ww, w)), + ) + for iw in range(nw) + if min((iw + 1) * ww, w) > iw * ww + for ih in range(nh) + if min((ih + 1) * wh, h) > ih * wh + for it in range(nt) + if min((it + 1) * wt, t) > it * wt + ] + +def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. + + st, sh, sw = ( # shift size. + 0.5 if wt < t else 0, + 0.5 if wh < h else 0, + 0.5 if ww < w else 0, + ) + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. + nt, nh, nw = ( # number of window. + nt + 1 if st > 0 else 1, + nh + 1 if sh > 0 else 1, + nw + 1 if sw > 0 else 1, + ) + return [ + ( + slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), + slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), + slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), + ) + for iw in range(nw) + if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) + for ih in range(nh) + if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) + for it in range(nt) + if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) + ] + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + freqs_for = 'lang', + theta = 10000, + max_freq = 10, + learned_freq = False, + cache_if_possible = True, + cache_max_seq_len = 8192 + ): + super().__init__() + + self.freqs_for = freqs_for + + if freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + + self.cache_if_possible = cache_if_possible + self.cache_max_seq_len = cache_max_seq_len + + self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_freqs_seq_len = 0 + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.register_buffer('dummy', torch.tensor(0), persistent = False) + + @property + def device(self): + return self.dummy.device + + def get_axial_freqs( + self, + *dims, + offsets = None + ): + Colon = slice(None) + all_freqs = [] + + # handle offset + + if exists(offsets): + assert len(offsets) == len(dims) + + for ind, dim in enumerate(dims): + + offset = 0 + if exists(offsets): + offset = offsets[ind] + + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + pos = pos + offset + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + # concat all freqs + + all_freqs = torch.broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + def forward( + self, + t, + seq_len: int | None = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and + not self.learned_freq and + exists(seq_len) and + self.freqs_for != 'pixel' and + (offset + seq_len) <= self.cache_max_seq_len + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs_seq_len + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache and offset == 0: + self.cached_freqs[:seq_len] = freqs.detach() + self.cached_freqs_seq_len = seq_len + + return freqs + +class RotaryEmbeddingBase(nn.Module): + def __init__(self, dim: int, rope_dim: int): + super().__init__() + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="pixel", + max_freq=BYTEDANCE_ROPE_MAX_FREQ, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + + def get_axial_freqs(self, *dims): + return self.rope.get_axial_freqs(*dims) + + +class RotaryEmbedding3d(RotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + self.mm = False + + +class NaRotaryEmbedding3d(RotaryEmbedding3d): + def forward( + self, + q: torch.FloatTensor, + k: torch.FloatTensor, + shape: torch.LongTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) + freqs = freqs.to(device=q.device) + q = rearrange(q, "L h d -> h L d") + k = rearrange(k, "L h d -> h L d") + q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) + k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "h L d -> L h d") + k = rearrange(k, "h L d -> L h d") + return q, k + + @torch._dynamo.disable + def get_freqs( + self, + shape: torch.LongTensor, + ) -> torch.Tensor: + # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds + # 7B pixel RoPE with the interleaved-angle convention, not Comfy's + # Flux freqs_cis matrix. + plain_rope = RotaryEmbedding( + dim=self.rope.freqs.numel() * 2, + freqs_for="pixel", + max_freq=BYTEDANCE_ROPE_MAX_FREQ, + ) + plain_rope = plain_rope.to(self.rope.dummy.device) + freq_list = [] + for f, h, w in shape.tolist(): + freqs = plain_rope.get_axial_freqs(f, h, w) + freq_list.append(freqs.view(-1, freqs.size(-1))) + return torch.cat(freq_list, dim=0) + + +class MMRotaryEmbeddingBase(RotaryEmbeddingBase): + def __init__(self, dim: int, rope_dim: int): + super().__init__(dim, rope_dim) + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="lang", + theta=ROPE_THETA, + cache_if_possible=False, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + self.mm = True + +def slice_at_dim(t, dim_slice: slice, *, dim): + dim += (t.ndim if dim < 0 else 0) + colons = [slice(None)] * t.ndim + colons[dim] = dim_slice + return t[tuple(colons)] + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') +def exists(val): + return val is not None + +def _apply_seedvr2_rotary_emb( + freqs: torch.Tensor, + t: torch.Tensor, + start_index: int = 0, + scale: float = 1.0, + seq_dim: int = -2, + freqs_seq_dim: int | None = None, +) -> torch.Tensor: + dtype = t.dtype + if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3): + freqs_seq_dim = 0 + + if t.ndim == 3 or freqs_seq_dim is not None: + seq_len = t.shape[seq_dim] + freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim) + + rot_feats = freqs.shape[-1] + end_index = start_index + rot_feats + + t_left = t[..., :start_index] + t_middle = t[..., start_index:end_index] + t_right = t[..., end_index:] + + freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype) + cos = freqs.cos() * scale + sin = freqs.sin() * scale + t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin) + return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) + +def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: + """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" + angles = freqs_interleaved[..., ::2].float() + cos = torch.cos(angles) + sin = torch.sin(angles) + out = torch.stack([cos, -sin, sin, cos], dim=-1) + return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) + + +def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest + through; in-place for inference, cloned for training (autograd). Mirrors the legacy + ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives + ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when + ``rot_d == t.shape[-1]``. + """ + out = t.clone() if t.requires_grad or comfy.model_management.in_training else t + rot_d = 2 * freqs_cis.shape[-3] + seq_len = out.shape[-2] + for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS): + end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len) + freqs_chunk = freqs_cis[start:end] + if rot_d == out.shape[-1]: + out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype) + else: + out[..., start:end, :rot_d] = apply_rope1(out[..., start:end, :rot_d], freqs_chunk).to(out.dtype) + return out + + +class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + + def forward( + self, + vid_q: torch.FloatTensor, # L h d + vid_k: torch.FloatTensor, # L h d + vid_shape: torch.LongTensor, # B 3 + txt_q: torch.FloatTensor, # L h d + txt_k: torch.FloatTensor, # L h d + txt_shape: torch.LongTensor, # B 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_freqs, txt_freqs = cache( + "mmrope_freqs_3d", + lambda: self.get_freqs(vid_shape, txt_shape), + ) + target_device = vid_q.device + if vid_freqs.device != target_device: + vid_freqs = vid_freqs.to(target_device) + if txt_freqs.device != target_device: + txt_freqs = txt_freqs.to(target_device) + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = _apply_rope1_partial(vid_q, vid_freqs) + vid_k = _apply_rope1_partial(vid_k, vid_freqs) + vid_q = rearrange(vid_q, "h L d -> L h d") + vid_k = rearrange(vid_k, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = _apply_rope1_partial(txt_q, txt_freqs) + txt_k = _apply_rope1_partial(txt_k, txt_freqs) + txt_q = rearrange(txt_q, "h L d -> L h d") + txt_k = rearrange(txt_k, "h L d -> L h d") + return vid_q, vid_k, txt_q, txt_k + + @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks + def get_freqs( + self, + vid_shape: torch.LongTensor, + txt_shape: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + ]: + + # Calculate actual max dimensions needed for this batch + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + + autocast_device = "cuda" if torch.cuda.is_available() else "cpu" + with torch.amp.autocast(autocast_device, enabled=False): + vid_freqs = self.get_axial_freqs( + max_temporal + 16, + max_height + 4, + max_width + 4, + ).float() + txt_freqs = self.get_axial_freqs(max_txt_len + 16) + + # Now slice as before + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) + txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) + txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) + + # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` + # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the + # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` + # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. + # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing + # 2x2 is the per-frequency rotation matrix that + # `comfy.ldm.flux.math.apply_rope1` expects. + return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) + +class MMModule(nn.Module): + def __init__( + self, + module: Callable[..., nn.Module], + *args, + shared_weights: bool = False, + vid_only: bool = False, + **kwargs, + ): + super().__init__() + self.shared_weights = shared_weights + self.vid_only = vid_only + if self.shared_weights: + assert get_args("vid", args) == get_args("txt", args) + assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + else: + self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + self.txt = ( + module(*get_args("txt", args), **get_kwargs("txt", kwargs)) + if not vid_only + else None + ) + + def forward( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + *args, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.vid if not self.shared_weights else self.all + vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) + if not self.vid_only: + txt_module = self.txt if not self.shared_weights else self.all + txt = txt.to(device=vid.device, dtype=vid.dtype) + txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) + return vid, txt + +def get_na_rope(rope_type: Optional[str], dim: int): + if rope_type is None: + return None + if rope_type == "rope3d": + return NaRotaryEmbedding3d(dim=dim) + if rope_type == "mmrope3d": + return NaMMRotaryEmbedding3d(dim=dim) + +class NaMMAttention(nn.Module): + def __init__( + self, + vid_dim: int, + txt_dim: int, + heads: int, + head_dim: int, + qk_bias: bool, + qk_norm, + qk_norm_eps: float, + rope_type: Optional[str], + rope_dim: int, + shared_weights: bool, + device, dtype, operations, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + self.heads = heads + inner_dim = heads * head_dim + qkv_dim = inner_dim * 3 + self.head_dim = head_dim + self.proj_qkv = MMModule( + operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype + ) + self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) + self.norm_q = MMModule( + qk_norm, + normalized_shape=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + device=device, dtype=dtype + ) + self.norm_k = MMModule( + qk_norm, + normalized_shape=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + device=device, dtype=dtype + ) + + + self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) + +def window( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid = unflatten(hid, hid_shape) + hid = list(map(window_fn, hid)) + hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid, hid_shape = flatten(list(chain(*hid))) + return hid, hid_shape, hid_windows + +def window_idx( + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + tgt_windows, + ) + +class NaSwinAttention(NaMMAttention): + def __init__( + self, + *args, + window: Union[int, Tuple[int, int, int]], + window_method: str, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.version_7b = kwargs.get("version", False) + self.window = _triple(window) + self.window_method = window_method + assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + + self.window_op = get_window_op(window_method) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + + # re-org the input seq for window attn + cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") + + def make_window(x: torch.Tensor): + t, h, w, _ = x.shape + window_slices = self.window_op((t, h, w), self.window) + return [x[st, sh, sw] for (st, sh, sw) in window_slices] + + window_partition, window_reverse, window_shape, window_count = cache_win( + "win_transform", + lambda: window_idx(vid_shape, make_window), + ) + vid_qkv_win = window_partition(vid_qkv) + + vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + + vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len = txt_len.to(window_count.device) + + # window rope + if self.rope: + if self.version_7b: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + elif self.rope.mm: + # repeat text q and k for window mmrope + _, num_h, _ = txt_q.shape + txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = unflatten(txt_q_repeat, txt_shape) + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = list(chain(*txt_q_repeat)) + txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) + txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + + txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = unflatten(txt_k_repeat, txt_shape) + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = list(chain(*txt_k_repeat)) + txt_k_repeat, _ = flatten(txt_k_repeat) + txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + + txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + concat_win, unconcat_win = cache_win( + "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) + ) + out = optimized_var_attention( + q=concat_win(vid_q, txt_q), + k=concat_win(vid_k, txt_k), + v=concat_win(vid_v, txt_v), + heads=self.heads, skip_reshape=True, skip_output_reshape=True, + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + ) + vid_out, txt_out = unconcat_win(out) + + vid_out = rearrange(vid_out, "l h d -> l (h d)") + txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = window_reverse(vid_out) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + + return vid_out, txt_out + +class MLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + device, dtype, operations + ): + super().__init__() + self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) + self.act = nn.GELU("tanh") + self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_in(x) + x = self.act(x) + x = self.proj_out(x) + return x + + +class SwiGLUMLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + multiple_of: int = 256, + device=None, dtype=None, operations=None + ): + super().__init__() + hidden_dim = int(2 * dim * expand_ratio / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) + self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) + +def get_mlp(mlp_type: Optional[str] = "normal"): + # 3b and 7b uses different mlp types + if mlp_type == "normal": + return MLP + elif mlp_type == "swiglu": + return SwiGLUMLP + +class NaMMSRTransformerBlock(nn.Module): + def __init__( + self, + *, + vid_dim: int, + txt_dim: int, + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm, + norm_eps: float, + ada, + qk_bias: bool, + qk_norm, + mlp_type: str, + shared_weights: bool, + rope_type: str, + rope_dim: int, + is_last_layer: bool, + device, dtype, operations, + **kwargs, + ): + super().__init__() + version = kwargs.get("version", False) + dim = MMArg(vid_dim, txt_dim) + self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) + + self.attn = NaSwinAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + rope_type=rope_type, + rope_dim=rope_dim, + shared_weights=shared_weights, + window=kwargs.pop("window", None), + window_method=kwargs.pop("window_method", None), + version=version, + device=device, dtype=dtype, operations=operations + ) + + self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) + self.mlp = MMModule( + get_mlp(mlp_type), + dim=dim, + expand_ratio=expand_ratio, + shared_weights=shared_weights, + vid_only=is_last_layer, + device=device, dtype=dtype, operations=operations + ) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) + self.is_last_layer = is_last_layer + self.version = version + + def _seedvr2_7b_mlp( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all + if comfy.model_management.in_training or vid.requires_grad: + vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0) + else: + vid_out = None + offset = 0 + for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0): + chunk_out = vid_module(chunk) + if vid_out is None: + vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:])) + vid_out[offset:offset + chunk_out.shape[0]] = chunk_out + offset += chunk_out.shape[0] + vid = vid_out + if not self.mlp.vid_only: + txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all + txt = txt.to(device=vid.device, dtype=vid.dtype) + txt = txt_module(txt) + return vid, txt + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + emb: torch.FloatTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.LongTensor, + torch.LongTensor, + ]: + hid_len = MMArg( + cache("vid_len", lambda: vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + if self.version: + vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp) + else: + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + device, dtype, operations + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + if t > 1: + vid = vid[:, :, (t - 1) :] + return vid + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + vid_shape_before_patchify = None + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + + t, h, w = self.patch_size + vid = self.proj(vid) + + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] + vid, vid_shape = flatten(vid) + + return vid, vid_shape + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + device, dtype, operations + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if t > 1: + assert vid.size(2) % t == 1 + vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> torch.Tensor: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) + t, h, w = self.patch_size + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) + vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + vid, vid_shape = flatten(vid) + + vid = self.proj(vid) + return vid, vid_shape + +def expand_dims(x: torch.Tensor, dim: int, ndim: int): + shape = x.shape + shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] + return x.reshape(shape) + + +class AdaSingle(nn.Module): + def __init__( + self, + dim: int, + emb_dim: int, + layers: List[str], + modes: List[str] = ["in", "out"], + device = None, dtype = None, + ): + assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + super().__init__() + self.dim = dim + self.emb_dim = emb_dim + self.layers = layers + + param_kwargs = {"device": device} + fp8_types = _torch_float8_types() + if dtype is not None and dtype not in fp8_types: + param_kwargs["dtype"] = dtype + + for l in layers: + if "in" in modes: + self.register_parameter(f"{l}_shift", nn.Parameter(torch.zeros(dim, **param_kwargs))) + self.register_parameter(f"{l}_scale", nn.Parameter(torch.ones(dim, **param_kwargs))) + if "out" in modes: + self.register_parameter(f"{l}_gate", nn.Parameter(torch.zeros(dim, **param_kwargs))) + + def forward( + self, + hid: torch.FloatTensor, # b ... c + emb: torch.FloatTensor, # b d + layer: str, + mode: str, + cache: Cache = Cache(disable=True), + branch_tag: str = "", + hid_len: Optional[torch.LongTensor] = None, # b + ) -> torch.FloatTensor: + idx = self.layers.index(layer) + emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = expand_dims(emb, 1, hid.ndim + 1) + + if hid_len is not None: + slice_inputs = lambda x, dim: x + emb = cache( + f"emb_repeat_{idx}_{branch_tag}", + lambda: slice_inputs( + torch.repeat_interleave(emb, hid_len, dim=0), + dim=0, + ), + ) + + shiftA, scaleA, gateA = emb.unbind(-1) + shiftB, scaleB, gateB = ( + getattr(self, f"{layer}_shift", None), + getattr(self, f"{layer}_scale", None), + getattr(self, f"{layer}_gate", None), + ) + + fp8_types = _torch_float8_types() + if fp8_types: + target_dtype = hid.dtype + + if shiftB is not None and shiftB.dtype in fp8_types: + shiftB = shiftB.to(target_dtype) + if scaleB is not None and scaleB.dtype in fp8_types: + scaleB = scaleB.to(target_dtype) + if gateB is not None and gateB.dtype in fp8_types: + gateB = gateB.to(target_dtype) + + if mode == "in": + return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) + if mode == "out": + if gateB is not None: + return hid.mul_(gateA + gateB) + else: + return hid.mul_(gateA) + + raise NotImplementedError + + +class TimeEmbedding(nn.Module): + def __init__( + self, + sinusoidal_dim: int, + hidden_dim: int, + output_dim: int, + device, dtype, operations + ): + super().__init__() + self.sinusoidal_dim = sinusoidal_dim + self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype) + self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) + self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) + self.act = nn.SiLU() + + def forward( + self, + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.FloatTensor: + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=device, dtype=dtype) + if timestep.ndim == 0: + timestep = timestep[None] + + emb = get_timestep_embedding( + timesteps=timestep, + embedding_dim=self.sinusoidal_dim, + flip_sin_to_cos=False, + downscale_freq_shift=0, + ).to(dtype) + emb = self.proj_in(emb) + emb = self.act(emb) + emb = self.proj_hid(emb) + emb = self.act(emb) + emb = self.proj_out(emb) + return emb + +def flatten( + hid: List[torch.FloatTensor], # List of (*** c) +) -> Tuple[ + torch.FloatTensor, # (L c) + torch.LongTensor, # (b n) +]: + assert len(hid) > 0 + shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + hid = torch.cat([x.flatten(0, -2) for x in hid]) + return hid, shape + + +def unflatten( + hid: torch.FloatTensor, # (L c) or (L ... c) + hid_shape: torch.LongTensor, # (b n) +) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) + hid_len = hid_shape.prod(-1) + hid = hid.split(hid_len.tolist()) + hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] + return hid + +def repeat( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, torch.LongTensor], # (b) +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + hid = unflatten(hid, hid_shape) + kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] + return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) + +class NaDiT(nn.Module): + + def __init__( + self, + norm_eps, + num_layers, + mlp_type, + vid_in_channels = 33, + vid_out_channels = 16, + vid_dim = 2560, + txt_in_dim = 5120, + heads = 20, + head_dim = 128, + mm_layers = 10, + expand_ratio = 4, + qk_bias = False, + patch_size = [ 1,2,2 ], + rope_dim = 128, + rope_type = "mmrope3d", + vid_out_norm: Optional[str] = None, + device = None, + dtype = None, + operations = None, + **kwargs, + ): + self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM + if self._7b_version: + rope_type = "rope3d" + self.dtype = dtype + factory_kwargs = {"device": device, "dtype": dtype} + window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] + txt_dim = vid_dim + emb_dim = vid_dim * 6 + window = num_layers * [(4,3,3)] + ada = AdaSingle + norm = CustomRMSNorm + qk_norm = CustomRMSNorm + super().__init__() + # ``torch.empty`` returns uninitialized memory, not zeros. The + # SeedVR2Conditioning fail-loud guard at + # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" + # from "buffer was never populated by the file" by checking + # ``positive_conditioning.abs().sum() == 0``. That sentinel is only + # reliable if the post-construction buffer state is deterministically + # zero, so explicitly zero-fill here rather than relying on the + # allocator's zero-on-alloc behavior (allocator-dependent and not + # contractual). When ``load_state_dict`` populates these buffers + # from a properly-baked SeedVR2 .safetensors, the in-place copy + # overwrites the zeros with the universal SeedVR2 conditioning + # tensors (shape (58, 5120) and (64, 5120) bf16). + self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + device=device, dtype=dtype, operations=operations + ) + self.txt_in = ( + operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + device=device, dtype=dtype, operations=operations + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + + rope_dim = rope_dim if rope_dim is not None else head_dim // 2 + self.blocks = nn.ModuleList( + [ + NaMMSRTransformerBlock( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_norm=qk_norm, + mlp_type=mlp_type, + rope_dim = rope_dim, + window=window[i], + window_method=window_method[i], + is_last_layer=(i == num_layers - 1) and not self._7b_version, + rope_type = rope_type, + shared_weights=not ( + (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] + ), + version = self._7b_version, + operations = operations, + **kwargs, + **factory_kwargs + ) + for i in range(num_layers) + ] + ) + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + device=device, dtype=dtype, operations=operations + ) + + self.vid_out_norm = None + if vid_out_norm is not None: + self.vid_out_norm = CustomRMSNorm( + normalized_shape=vid_dim, + eps=norm_eps, + elementwise_affine=True, + device=device, dtype=dtype + ) + self.vid_out_ada = ada( + dim=vid_dim, + emb_dim=emb_dim, + layers=["out"], + modes=["in"], + device=device, dtype=dtype + ) + + def _resolve_text_conditioning(self, context, cond_or_uncond=None): + if context is None or getattr(context, "numel", lambda: None)() == 0: + context = self.positive_conditioning + return flatten([context]) + if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): + if context.shape[0] == 1: + context = context.squeeze(0) + return flatten([context]) + return flatten(context.unbind(0)) + if context.shape[0] % 2 != 0: + raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}") + neg_cond, pos_cond = context.chunk(2, dim=0) + if pos_cond.shape[0] == 1: + pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) + return flatten([pos_cond, neg_cond]) + return flatten((*pos_cond.unbind(0), *neg_cond.unbind(0))) + + @staticmethod + def _seedvr2_is_single_conditioning_branch(cond_or_uncond): + if cond_or_uncond is None or len(cond_or_uncond) == 0: + return False + first = cond_or_uncond[0] + return all(entry == first for entry in cond_or_uncond) + + def _swap_pos_neg_halves(self, out, cond_or_uncond=None): + if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): + return out + # ``dim=0`` is explicit on both calls. The contract is "split + # the batch axis into two halves and swap them"; making the + # axis load-bearing in source guards against silent drift if a + # future refactor reorders tensor axes. + pos, neg = out.chunk(2, dim=0) + return torch.cat([neg, pos], dim=0) + + def forward( + self, + x, + timestep, + context, # l c + disable_cache: bool = False, # for test # TODO ? // gives an error when set to True + **kwargs + ): + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + conditions = kwargs.get("condition") + b, tc, h, w = x.shape + x = x.view(b, 16, -1, h, w) + conditions = conditions.view(b, 17, -1, h, w) + x = x.movedim(1, -1) + conditions = conditions.movedim(1, -1) + cache = Cache(disable=disable_cache) + + txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond")) + + vid, vid_shape = flatten(x) + cond_latent, _ = flatten(conditions) + + vid = torch.cat([vid, cond_latent], dim=-1) + + txt = self.txt_in(txt) + + vid_shape_before_patchify = vid_shape + vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) + + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + + for i, block in enumerate(self.blocks): + if ("block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( + vid=args["vid"], + txt=args["txt"], + vid_shape=args["vid_shape"], + txt_shape=args["txt_shape"], + emb=args["emb"], + cache=args["cache"], + ) + return out + out = blocks_replace[("block", i)]({ + "vid":vid, + "txt":txt, + "vid_shape":vid_shape, + "txt_shape":txt_shape, + "emb":emb, + "cache":cache, + }, {"original_block": block_wrap}) + vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] + else: + vid, txt, vid_shape, txt_shape = block( + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + if self.vid_out_norm: + vid = self.vid_out_norm(vid) + vid = self.vid_out_ada( + vid, + emb=emb, + layer="out", + mode="in", + hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), + cache=cache, + branch_tag="vid", + ) + + vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) + vid = unflatten(vid, vid_shape) + out = torch.stack(vid) + out = out.movedim(-1, 1) + out = rearrange(out, "b c t h w -> b (c t) h w") + return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2289e0812..6fc306161 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,6 +54,7 @@ import comfy.ldm.pixeldit.model import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.seedvr.model import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model @@ -929,6 +930,16 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out +class SeedVR2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + condition = kwargs.get("condition", None) + if condition is not None: + out["condition"] = comfy.conds.CONDRegular(condition) + return out + class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0cab308..1128a4d3c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,6 +598,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` + # submodules) at EVERY block — verified by inspecting the 7B + # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means + # ``MMModule.shared_weights=False``). Native NaDiT computes + # per-block ``shared_weights = not (i < mm_layers)``, so to keep + # every block non-shared we set ``mm_layers = num_layers``. + # Without this, blocks at index >= mm_layers (default 10) try to + # load ``blocks.N.*.all.*`` keys that don't exist in the file, + # silently miss-load → all-black output. + dit_config["mm_layers"] = 36 + dit_config["norm_eps"] = 1e-5 + dit_config["rope_type"] = "rope3d" + dit_config["rope_dim"] = 64 + dit_config["mlp_type"] = "normal" + return dit_config + elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + # This checkpoint layout carries shared ``all.`` MMModule keys. + # Preserve the historical split: the initial blocks use separate + # vid/txt modules, later blocks use shared modules. + dit_config["mm_layers"] = 10 + dit_config["norm_eps"] = 1e-5 + dit_config["rope_type"] = "rope3d" + dit_config["rope_dim"] = 64 + dit_config["mlp_type"] = "swiglu" + return dit_config + elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 2560 + dit_config["heads"] = 20 + dit_config["num_layers"] = 32 + dit_config["norm_eps"] = 1.0e-05 + dit_config["mlp_type"] = "swiglu" + dit_config["vid_out_norm"] = True + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3be935577..5ba6375df 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1683,6 +1683,35 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class SeedVR2(supported_models_base.BASE): + unet_config = { + "image_model": "seedvr2" + } + latent_format = comfy.latent_formats.SeedVR2 + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + sampling_settings = { + "shift": 1.0, + } + + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + if ( + dtype == torch.float16 + and manual_cast_dtype is None + and comfy.model_management.should_use_bf16(device) + ): + manual_cast_dtype = torch.bfloat16 + super().set_inference_dtype(dtype, manual_cast_dtype, device=device) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SeedVR2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None + class ChromaRadiance(Chroma): unet_config = { "image_model": "chroma_radiance", @@ -2296,6 +2325,7 @@ models = [ HiDream, HiDreamO1, Chroma, + SeedVR2, ChromaRadiance, ACEStep, ACEStep15, diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 0e7a829ba..572f9984e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -115,7 +115,7 @@ class BASE: replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_inference_dtype(self, dtype, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype From a7ea0c277380e6b74423e4ccca46ba2f09eecf40 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:39:54 -0500 Subject: [PATCH 02/12] Add SeedVR2 VAE support --- comfy/ldm/seedvr/vae.py | 1807 +++++++++++++++++++++++++++++++++++++++ comfy/sd.py | 125 ++- 2 files changed, 1912 insertions(+), 20 deletions(-) create mode 100644 comfy/ldm/seedvr/vae.py diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py new file mode 100644 index 000000000..3996b9103 --- /dev/null +++ b/comfy/ldm/seedvr/vae.py @@ -0,0 +1,1807 @@ +from contextlib import nullcontext +from typing import Literal, Optional, Tuple +import gc +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor +from contextlib import contextmanager +from comfy.utils import ProgressBar + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_BLOCK_OUT_CHANNELS, + BYTEDANCE_GN_CHUNKS_FP16, + BYTEDANCE_GN_CHUNKS_FP32, + BYTEDANCE_LOGVAR_CLAMP_MAX, + BYTEDANCE_LOGVAR_CLAMP_MIN, + BYTEDANCE_SLICING_SAMPLE_MIN, + BYTEDANCE_VAE_CONV_MEM_GIB, + BYTEDANCE_VAE_NORM_MEM_GIB, + BYTEDANCE_VAE_SCALING_FACTOR, + BYTEDANCE_VAE_SHIFTING_FACTOR, + BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, + BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, + SEEDVR2_LATENT_CHANNELS, +) +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention + +import math +from enum import Enum +from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND + +import logging +import comfy.model_management +import comfy.ops +ops = comfy.ops.disable_weight_init + + +def _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, temporal_scale=1): + if temporal_size is None: + return None + + temporal_size = int(temporal_size) + if temporal_size <= 0: + return 0 + + temporal_overlap = max(0, int(temporal_overlap or 0)) + temporal_overlap = min(temporal_overlap, temporal_size - 1) + temporal_step = temporal_size - temporal_overlap + temporal_scale = max(1, int(temporal_scale)) + return max(1, math.ceil(temporal_step / temporal_scale)) + + +def _seedvr2_clamped_spatial_overlap(overlap, tile_size): + overlap = max(0, int(overlap)) + tile_size = max(1, int(tile_size)) + return min(overlap, tile_size - 1) + + +def _seedvr2_clear_temporal_memory(model): + for module in model.modules(): + if hasattr(module, "memory"): + module.memory = None + + +@torch.inference_mode() +def tiled_vae( + x, + vae_model, + tile_size=(512, 512), + tile_overlap=(64, 64), + temporal_size=16, + temporal_overlap=0, + encode=True, + **kwargs, +): + gc.collect() + comfy.model_management.soft_empty_cache() + + x = x.to(next(vae_model.parameters()).dtype) + if x.ndim != 5: + x = x.unsqueeze(2) + + _, _, d, h, w = x.shape + + sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) + sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) + if encode: + slicing_attr = "slicing_sample_min_size" + slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) + else: + slicing_attr = "slicing_latent_min_size" + slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, sf_t) + if encode: + ti_h, ti_w = tile_size + ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0], ti_h) + ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1], ti_w) + blend_ov_h = max(0, ov_h // sf_s) + blend_ov_w = max(0, ov_w // sf_s) + target_d = (d + sf_t - 1) // sf_t + target_h = (h + sf_s - 1) // sf_s + target_w = (w + sf_s - 1) // sf_s + else: + ti_h = max(1, tile_size[0] // sf_s) + ti_w = max(1, tile_size[1] // sf_s) + ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0] // sf_s, ti_h) + ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1] // sf_s, ti_w) + blend_ov_h = ov_h * sf_s + blend_ov_w = ov_w * sf_s + + target_d = max(1, d * sf_t - (sf_t - 1)) + target_h = h * sf_s + target_w = w * sf_s + + stride_h = max(1, ti_h - ov_h) + stride_w = max(1, ti_w - ov_w) + + storage_device = vae_model.device + result = None + count = None + def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): + device = torch.device(device) + _seedvr2_clear_temporal_memory(model) + t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() + old_device = getattr(model, "device", None) + model.device = device + old_slicing_min_size = getattr(model, slicing_attr, None) + if old_slicing_min_size is not None and slicing_min_size is not None: + if slicing_min_size <= 0: + setattr(model, slicing_attr, t_chunk.shape[2]) + else: + setattr(model, slicing_attr, slicing_min_size) + try: + if encode: + out = model.encode(t_chunk)[0] + else: + out = model.decode_(t_chunk) + finally: + if old_slicing_min_size is not None and slicing_min_size is not None: + setattr(model, slicing_attr, old_slicing_min_size) + if old_device is not None: + model.device = old_device + if isinstance(out, (tuple, list)): + out = out[0] + if out.ndim == 4: + out = out.unsqueeze(2) + return out.to(storage_device) + + ramp_cache = {} + def get_ramp(steps): + if steps not in ramp_cache: + t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) + ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) + return ramp_cache[steps] + + tile_ranges = [] + for y_idx in range(0, h, stride_h): + y_end = min(y_idx + ti_h, h) + if y_idx > 0 and (y_end - y_idx) <= ov_h: + continue + for x_idx in range(0, w, stride_w): + x_end = min(x_idx + ti_w, w) + if x_idx > 0 and (x_end - x_idx) <= ov_w: + continue + tile_ranges.append((y_idx, y_end, x_idx, x_end)) + + total_tiles = len(tile_ranges) + bar = ProgressBar(total_tiles) + single_spatial_tile = h <= ti_h and w <= ti_w + + _seedvr2_clear_temporal_memory(vae_model) + + def run_tile(tile_index, tile_range): + y_idx, y_end, x_idx, x_end = tile_range + tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] + tile_out = run_temporal_chunks(tile_x) + return tile_index, y_idx, y_end, x_idx, x_end, tile_out + + ordered_tile_outputs = ( + run_tile(tile_index, tile_range) + for tile_index, tile_range in enumerate(tile_ranges) + ) + + for _, y_idx, y_end, x_idx, x_end, tile_out in ordered_tile_outputs: + + if single_spatial_tile: + result = tile_out[:, :, :target_d, :target_h, :target_w] + _seedvr2_clear_temporal_memory(vae_model) + if result.device != x.device: + result = result.to(x.device).to(x.dtype) + if x.shape[2] == 1 and sf_t == 1: + result = result.squeeze(2) + bar.update(1) + return result + + if result is None: + b_out, c_out = tile_out.shape[0], tile_out.shape[1] + result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) + + if encode: + ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] + xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) + else: + ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] + xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] + cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) + cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) + + w_h = torch.ones((tile_out.shape[3],), device=storage_device) + w_w = torch.ones((tile_out.shape[4],), device=storage_device) + + if cur_ov_h > 0: + r = get_ramp(cur_ov_h) + if y_idx > 0: + w_h[:cur_ov_h] = r + if y_end < h: + w_h[-cur_ov_h:] = 1.0 - r + + if cur_ov_w > 0: + r = get_ramp(cur_ov_w) + if x_idx > 0: + w_w[:cur_ov_w] = r + if x_end < w: + w_w[-cur_ov_w:] = 1.0 - r + + final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) + + valid_d = min(tile_out.shape[2], result.shape[2]) + tile_out = tile_out[:, :, :valid_d, :, :] + + tile_out.mul_(final_weight) + + result[:, :, :valid_d, ys:ye, xs:xe] += tile_out + count[:, :, :, ys:ye, xs:xe] += final_weight + + del tile_out, final_weight, w_h, w_w + bar.update(1) + + result.div_(count.clamp(min=1e-6)) + _seedvr2_clear_temporal_memory(vae_model) + + if result.device != x.device: + result = result.to(x.device).to(x.dtype) + + if x.shape[2] == 1 and sf_t == 1: + result = result.squeeze(2) + + return result + +_NORM_LIMIT = float("inf") +def get_norm_limit(): + return _NORM_LIMIT + + +def set_norm_limit(value: Optional[float] = None): + global _NORM_LIMIT + if value is None: + value = float("inf") + _NORM_LIMIT = value + +@contextmanager +def ignore_padding(model): + orig_padding = model.padding + model.padding = (0, 0, 0) + try: + yield + finally: + model.padding = orig_padding + +class MemoryState(Enum): + DISABLED = 0 + INITIALIZING = 1 + ACTIVE = 2 + UNSET = 3 + +def get_cache_size(conv_module, input_len, pad_len, dim=0): + dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + remain_len = ( + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) + ) + overlap_len = dilated_kernerl_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len # >= 0 + + assert output_len > 0 + return cache_len + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) + + def mode(self): + return self.mean + +class SpatialNorm(nn.Module): + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + +# partial implementation of diffusers's Attention for comfyui +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_softmax: bool = False, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + out_dim: int = None, + pre_only=False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.pre_only = pre_only + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + + self.only_cross_attention = only_cross_attention + + if norm_num_groups is not None: + self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + self.norm_q = None + self.norm_k = None + + self.norm_cross = None + self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + self.optimized_vae_attention = vae_attention() + + def __call__( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + + residual = hidden_states + if self.spatial_norm is not None: + hidden_states = self.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: + query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) + else: + hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if self.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / self.rescale_output_factor + + return hidden_states + + +def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x.to(input_dtype) + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x.to(input_dtype) + if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4: + return norm_layer(x).to(input_dtype) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + memory_occupy = x.numel() * x.element_size() / 1024**3 + if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): + num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) + assert norm_layer.num_groups % num_chunks == 0 + num_groups_per_chunk = norm_layer.num_groups // num_chunks + + x = list(x.chunk(num_chunks, dim=1)) + weights = norm_layer.weight.chunk(num_chunks, dim=0) + biases = norm_layer.bias.chunk(num_chunks, dim=0) + for i, (w, b) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + x[i] = x[i].to(input_dtype) + x = torch.cat(x, dim=1) + else: + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x.to(input_dtype) + raise NotImplementedError + +def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): + problematic_modes = ['bilinear', 'bicubic', 'trilinear'] + + if mode in problematic_modes: + try: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + except RuntimeError as e: + if ("not implemented for 'Half'" in str(e) or + "compute_indices_weights" in str(e)): + original_dtype = x.dtype + return F.interpolate( + x.float(), + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ).to(original_dtype) + else: + raise e + else: + # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + +_receptive_field_t = Literal["half", "full"] + +def extend_head(tensor, times: int = 2, memory = None): + if memory is not None: + return torch.cat((memory.to(tensor), tensor), dim=2) + assert times >= 0, "Invalid input for function 'extend_head'!" + if times == 0: + return tensor + else: + tile_repeat = [1] * tensor.ndim + tile_repeat[2] = times + return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) + +def cache_send_recv(tensor, cache_size, times, memory=None): + recv_buffer = None + + if memory is not None: + recv_buffer = memory.to(tensor[0]) + elif times > 0: + tile_repeat = [1] * tensor[0].ndim + tile_repeat[2] = times + recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) + + return recv_buffer + +class InflatedCausalConv3d(ops.Conv3d): + def __init__( + self, + *args, + inflation_mode, + **kwargs, + ): + self.inflation_mode = inflation_mode + self.memory = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.padding = (0, *self.padding[1:]) + self.memory_limit = float("inf") + self.logged_once = False + + def set_memory_limit(self, value: float): + self.memory_limit = value + + def _conv_forward(self, input, weight, bias, *args, **kwargs): + if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and + weight.dtype in (torch.float16, torch.bfloat16) and + hasattr(torch.backends.cudnn, 'is_available') and + torch.backends.cudnn.is_available() and + getattr(torch.backends.cudnn, 'enabled', True)): + try: + out = torch.cudnn_convolution( + input, weight, self.padding, self.stride, self.dilation, self.groups, + benchmark=False, deterministic=False, allow_tf32=True + ) + if bias is not None: + out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) + return out + except RuntimeError: + pass + except NotImplementedError: + pass + try: + return super()._conv_forward(input, weight, bias, *args, **kwargs) + except NotImplementedError: + # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend + if not self.logged_once: + logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory") + self.logged_once = True + return F.conv3d(input, weight, bias, *args, **kwargs) + + def memory_limit_conv( + self, + x, + *, + split_dim=3, + padding=(0, 0, 0, 0, 0, 0), + prev_cache=None, + ): + # Compatible with no limit. + if math.isinf(self.memory_limit): + if prev_cache is not None: + x = torch.cat([prev_cache, x], dim=split_dim - 1) + return super().forward(x) + + # Compute tensor shape after concat & padding. + shape = torch.tensor(x.size()) + if prev_cache is not None: + shape[split_dim - 1] += prev_cache.size(split_dim - 1) + shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) + memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + if memory_occupy < self.memory_limit or split_dim == x.ndim: + x_concat = x + if prev_cache is not None: + x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) + + def pad_and_forward(): + padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + if not padded.is_contiguous(): + padded = padded.contiguous() + with ignore_padding(self): + return torch.nn.Conv3d.forward(self, padded) + + return pad_and_forward() + + num_splits = math.ceil(memory_occupy / self.memory_limit) + size_per_split = x.size(split_dim) // num_splits + split_sizes = [size_per_split] * (num_splits - 1) + split_sizes += [x.size(split_dim) - sum(split_sizes)] + + x = list(x.split(split_sizes, dim=split_dim)) + if prev_cache is not None: + prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) + cache = None + for idx in range(len(x)): + if prev_cache is not None: + x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) + + lpad_dim = (x[idx].ndim - split_dim - 1) * 2 + rpad_dim = lpad_dim + 1 + padding = list(padding) + padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 + padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 + pad_len = padding[lpad_dim] + padding[rpad_dim] + padding = tuple(padding) + + next_cache = None + cache_len = cache.size(split_dim) if cache is not None else 0 + next_catch_size = get_cache_size( + conv_module=self, + input_len=x[idx].size(split_dim) + cache_len, + pad_len=pad_len, + dim=split_dim - 2, + ) + if next_catch_size != 0: + assert next_catch_size <= x[idx].size(split_dim) + next_cache = ( + x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) + ) + + x[idx] = self.memory_limit_conv( + x[idx], + split_dim=split_dim + 1, + padding=padding, + prev_cache=cache + ) + + cache = next_cache + + output = torch.cat(x, dim=split_dim) + return output + + def forward( + self, + input, + memory_state: MemoryState = MemoryState.UNSET + ) -> Tensor: + assert memory_state != MemoryState.UNSET + if memory_state != MemoryState.ACTIVE: + self.memory = None + if ( + math.isinf(self.memory_limit) + and torch.is_tensor(input) + ): + return self.basic_forward(input, memory_state) + return self.slicing_forward(input, memory_state) + + def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): + mem_size = self.stride[0] - self.kernel_size[0] + if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.memory, times=-1) + else: + input = extend_head(input, times=self.temporal_padding * 2) + memory = ( + input[:, :, mem_size:].detach() + if (mem_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if memory_state != MemoryState.DISABLED: + self.memory = memory + return super().forward(input) + + def slicing_forward( + self, + input, + memory_state: MemoryState = MemoryState.UNSET, + ) -> Tensor: + squeeze_out = False + if torch.is_tensor(input): + input = [input] + squeeze_out = True + + cache_size = self.kernel_size[0] - self.stride[0] + cache = cache_send_recv( + input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 + ) + + # Single GPU inference - simplified memory management + if ( + memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing + and cache_size != 0 + ): + if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: + input[0] = torch.cat([cache, input[0]], dim=2) + cache = None + if cache_size <= input[-1].size(2): + self.memory = input[-1][:, :, -cache_size:].detach().contiguous() + + padding = tuple(x for x in reversed(self.padding) for _ in range(2)) + for i in range(len(input)): + # Prepare cache for next input slice. + next_cache = None + cache_size = 0 + if i < len(input) - 1: + cache_len = cache.size(2) if cache is not None else 0 + cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) + if cache_size != 0: + if cache_size > input[i].size(2) and cache is not None: + input[i] = torch.cat([cache, input[i]], dim=2) + cache = None + assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" + next_cache = input[i][:, :, -cache_size:] + + # Conv forward for this input slice. + input[i] = self.memory_limit_conv( + input[i], + padding=padding, + prev_cache=cache + ) + + # Update cache. + cache = next_cache + + return input[0] if squeeze_out else input + +def remove_head(tensor: Tensor, times: int = 1) -> Tensor: + if times == 0: + return tensor + return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) + +class Upsample3D(nn.Module): + + def __init__( + self, + channels, + out_channels = None, + inflation_mode = "tail", + temporal_up: bool = False, + spatial_up: bool = True, + **kwargs, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + conv = InflatedCausalConv3d( + self.channels, + self.out_channels, + 3, + padding=1, + inflation_mode=inflation_mode, + ) + + self.temporal_up = temporal_up + self.spatial_up = spatial_up + self.temporal_ratio = 2 if temporal_up else 1 + self.spatial_ratio = 2 if spatial_up else 1 + + # [Override] MAGViT v2 learnable upsample + upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio + self.upscale_conv = ops.Conv3d( + self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 + ) + identity = ( + torch.eye(self.channels) + .repeat(upscale_ratio, 1) + .reshape_as(self.upscale_conv.weight) + ) + self.upscale_conv.weight.data.copy_(identity) + + self.conv = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state=None, + **kwargs, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + hidden_states = self.upscale_conv(hidden_states) + hidden_states = rearrange( + hidden_states, + "b (x y z c) f h w -> b c (f z) (h x) (w y)", + x=self.spatial_ratio, + y=self.spatial_ratio, + z=self.temporal_ratio, + ) + + if self.temporal_up and memory_state != MemoryState.ACTIVE: + hidden_states = remove_head(hidden_states) + + hidden_states = self.conv(hidden_states, memory_state=memory_state) + + return hidden_states + + +class Downsample3D(nn.Module): + """A 3D downsampling layer with an optional convolution.""" + + def __init__( + self, + channels, + out_channels = None, + inflation_mode = "tail", + spatial_down: bool = False, + temporal_down: bool = False, + **kwargs, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.temporal_down = temporal_down + self.spatial_down = spatial_down + + self.temporal_ratio = 2 if temporal_down else 1 + self.spatial_ratio = 2 if spatial_down else 1 + + self.temporal_kernel = 3 if temporal_down else 1 + self.spatial_kernel = 3 if spatial_down else 1 + + self.conv = InflatedCausalConv3d( + self.channels, + self.out_channels, + kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + padding=(1 if self.temporal_down else 0, 0, 0), + inflation_mode=inflation_mode, + ) + + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state = None, + **kwargs, + ) -> torch.FloatTensor: + + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. + hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.spatial_down: + pad = (0, 1, 0, 1) + hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states, memory_state=memory_state) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + temb_channels: int = 512, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + output_scale_factor: float = 1.0, + skip_time_act: bool = False, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + **kwargs, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.output_scale_factor = output_scale_factor + self.skip_time_act = skip_time_act + self.nonlinearity = nn.SiLU() + if temb_channels is not None: + self.time_emb_proj = ops.Linear(temb_channels, out_channels) + else: + self.time_emb_proj = None + self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + if groups_out is None: + groups_out = groups + self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.use_in_shortcut = self.in_channels != out_channels + self.dropout = torch.nn.Dropout(dropout) + self.conv1 = InflatedCausalConv3d( + self.in_channels, + self.out_channels, + kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), + stride=1, + padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), + inflation_mode=inflation_mode, + ) + + self.conv2 = InflatedCausalConv3d( + self.out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedCausalConv3d( + self.in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=True, + inflation_mode=inflation_mode, + ) + + def forward( + self, input_tensor, temb, memory_state = None, **kwargs + ): + hidden_states = input_tensor + + hidden_states = causal_norm_wrapper(self.norm1, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states, memory_state=memory_state) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if temb is not None: + hidden_states = hidden_states + temb + + hidden_states = causal_norm_wrapper(self.norm2, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, memory_state=memory_state) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_down: bool = True, + spatial_down: bool = True, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + output_scale_factor=output_scale_factor, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, + out_channels=out_channels, + temporal_down=temporal_down, + spatial_down=spatial_down, + inflation_mode=inflation_mode, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state = None, + **kwargs, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = temporal(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up: bool = True, + spatial_up: bool = True, + ): + super().__init__() + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + output_scale_factor=output_scale_factor, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_upsample: + # [Override] Replace module & use learnable upsample + self.upsamplers = nn.ModuleList( + [ + Upsample3D( + out_channels, + out_channels=out_channels, + temporal_up=temporal_up, + spatial_up=spatial_up, + inflation_mode=inflation_mode, + ) + ] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + memory_state=None + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = temporal(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + output_scale_factor=output_scale_factor, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=( + resnet_groups if resnet_time_scale_shift == "default" else None + ), + spatial_norm_dim=( + temb_channels if resnet_time_scale_shift == "spatial" else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + output_scale_factor=output_scale_factor, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, memory_state=None): + video_length, frame_height, frame_width = hidden_states.size()[-3:] + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn(hidden_states, temb=temb) + hidden_states = rearrange( + hidden_states, "(b f) c h w -> b c f h w", f=video_length + ) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state) + + return hidden_states + + +class Encoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + mid_block_add_attention=True, + # [Override] add temporal down num + temporal_down_num: int = 2, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_down_num = temporal_down_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + # [Override] to support temporal down block design + is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 + # Note: take the last ones + + assert down_block_type == "DownEncoderBlock3D" + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + temporal_down=is_temporal_down_block, + spatial_down=True, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # out + self.conv_norm_out = ops.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels + self.conv_out = InflatedCausalConv3d( + block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + + def forward( + self, + sample: torch.FloatTensor, + memory_state = None + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + sample = sample.to(next(self.parameters()).device) + sample = self.conv_in(sample, memory_state = memory_state) + # down + for down_block in self.down_blocks: + sample = down_block(sample, memory_state=memory_state) + + # middle + sample = self.mid_block(sample, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state = memory_state) + + return sample + + +class Decoder3D(nn.Module): + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + mid_block_add_attention=True, + # [Override] add temporal up block + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + assert up_block_type == "UpDecoderBlock3D" + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + temporal_up=is_temporal_up_block, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = ops.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = InflatedCausalConv3d( + block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + + # Note: Just copy from Decoder. + def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + memory_state = None, + ) -> torch.FloatTensor: + + sample = sample.to(next(self.parameters()).device) + sample = self.conv_in(sample, memory_state=memory_state) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + # middle + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state=memory_state) + + return sample + +class VideoAutoencoderKL(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + layers_per_block: int = 2, + latent_channels: int = SEEDVR2_LATENT_CHANNELS, + norm_num_groups: int = 32, + temporal_scale_num: int = 2, + inflation_mode = "pad", + time_receptive_field: _receptive_field_t = "full", + slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, + *args, + **kwargs, + ): + self.slicing_sample_min_size = slicing_sample_min_size + self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) + block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS + down_block_types = ("DownEncoderBlock3D",) * 4 + up_block_types = ("UpDecoderBlock3D",) * 4 + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + # [Override] add temporal_down_num parameter + temporal_down_num=temporal_scale_num, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + # [Override] add temporal_up_num parameter + temporal_up_num=temporal_scale_num, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + self.use_slicing = True + + def encode(self, x: torch.FloatTensor, return_dict: bool = True): + h = self.slicing_encode(x) + posterior = DiagonalGaussianDistribution(h).mode() + + if not return_dict: + return (posterior,) + + return posterior + + def decode_( + self, z: torch.Tensor, return_dict: bool = True + ): + decoded = self.slicing_decode(z) + + if not return_dict: + return (decoded,) + + return decoded + + def _encode( + self, x, memory_state = MemoryState.DISABLED + ) -> torch.Tensor: + _x = x.to(self.device) + h = self.encoder(_x, memory_state=memory_state) + return h.to(x.device) + + def _decode( + self, z, memory_state = MemoryState.DISABLED + ) -> torch.Tensor: + _z = z.to(self.device) + output = self.decoder(_z, memory_state=memory_state) + return output.to(z.device) + + def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: + sp_size =1 + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + split_size = max( + self.slicing_sample_min_size * sp_size, + getattr(self, "temporal_downsample_factor", 1), + ) + x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) + min_active_len = getattr(self, "temporal_downsample_factor", 1) + if len(x_slices) > 1 and x_slices[-1].shape[2] < min_active_len: + x_slices[-2] = torch.cat((x_slices[-2], x_slices[-1]), dim=2) + x_slices.pop() + encoded_slices = [ + self._encode( + torch.cat((x[:, :, :1], x_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for x_idx in range(1, len(x_slices)): + encoded_slices.append( + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + ) + out = torch.cat(encoded_slices, dim=2) + modules_with_memory = [m for m in self.modules() + if isinstance(m, InflatedCausalConv3d) and m.memory is not None] + for m in modules_with_memory: + m.memory = None + return out + else: + return self._encode(x) + + def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: + sp_size = 1 + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + decoded_slices = [ + self._decode( + torch.cat((z[:, :, :1], z_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING + ) + ] + for z_idx in range(1, len(z_slices)): + decoded_slices.append( + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + ) + out = torch.cat(decoded_slices, dim=2) + modules_with_memory = [m for m in self.modules() + if isinstance(m, InflatedCausalConv3d) and m.memory is not None] + for m in modules_with_memory: + m.memory = None + return out + else: + return self._decode(z) + + def forward( + self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs + ): + # x: [b c t h w] + def _unwrap(value): + return value[0] if isinstance(value, tuple) else value + + if mode == "encode": + return _unwrap(self.encode(x)) + elif mode == "decode": + return _unwrap(self.decode_(x)) + else: + latent = _unwrap(self.encode(x)) + return _unwrap(self.decode_(latent)) + +class VideoAutoencoderKLWrapper(VideoAutoencoderKL): + # Signals to comfy.sd.VAE that this model performs its own VAE tiling, so the + # generic tiled-decode/encode dispatch defers to decode_tiled/encode_tiled below. + comfy_handles_tiling = True + + def __init__( + self, + *args, + spatial_downsample_factor = 8, + temporal_downsample_factor = 4, + freeze_encoder = True, + **kwargs, + ): + self.spatial_downsample_factor = spatial_downsample_factor + self.temporal_downsample_factor = temporal_downsample_factor + self.freeze_encoder = freeze_encoder + self.enable_tiling = False + super().__init__(*args, **kwargs) + self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) + + def forward(self, x: torch.FloatTensor): + with torch.no_grad() if self.freeze_encoder else nullcontext(): + z, p = self.encode(x) + x = self.decode(z) + return x, z, p + + def encode(self, x, orig_dims=None): + if x.ndim == 4: + x = x.unsqueeze(2) + x = x.to(dtype=next(self.parameters()).dtype) + self.device = x.device + p = super().encode(x) + z = p.squeeze(2) + return z, p + + def decode(self, z, seedvr2_tiling=None): + seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling + if not isinstance(seedvr2_tiling, dict): + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: `seedvr2_tiling` must be a dict; " + f"got {type(seedvr2_tiling).__name__} with value {seedvr2_tiling!r}." + ) + + if z.ndim == 5: + b, c, t_latent, h, w = z.shape + if c != 16: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " + f"have 16 channels; got shape {tuple(z.shape)}." + ) + latent = z + elif z.ndim == 4: + b, tc, h, w = z.shape + if tc % 16 != 0: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " + "use collapsed channel layout (B, 16*T, H, W); " + f"got shape {tuple(z.shape)}." + ) + latent = z.reshape(b, 16, -1, h, w) + else: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " + "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " + f"got shape {tuple(z.shape)}." + ) + scale = BYTEDANCE_VAE_SCALING_FACTOR + shift = BYTEDANCE_VAE_SHIFTING_FACTOR + latent = latent / scale + shift + + self.device = latent.device + self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) + + if self.enable_tiling: + decode_seedvr2_args = dict(seedvr2_tiling) + tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) + ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) + decode_seedvr2_args["tile_overlap"] = ( + min(ov_h, max(0, tile_h - 8)), + min(ov_w, max(0, tile_w - 8)), + ) + x = tiled_vae(latent, self, **decode_seedvr2_args, encode=False) + if x.ndim == 4: + # tiled_vae squeezes the temporal axis when + # temporal_downsample_factor == 1 AND latent T == 1 + # (see tiled_vae line 179-180); re-add it so the post-decode + # pipeline can keep batch and time distinct on the tiled path. + x = x.unsqueeze(2) + else: + x = super().decode_(latent) + + # ensure even dims for save video + h, w = x.shape[-2:] + w2 = w - (w % 2) + h2 = h - (h % 2) + x = x[..., :h2, :w2] + + return x + + def decode_tiled(self, z, tile_x=32, tile_y=32, overlap=8, tile_t=None, overlap_t=None): + # SeedVR2's causal VAE owns temporal via the MemoryState cache; temporal + # slicing breaks that continuity (empirically corrupts decode), so the VAE + # tiling knobs (tile_t / overlap_t) are discarded and temporal stays whole. + sf = self.spatial_downsample_factor + seedvr2_tiling = { + "enable_tiling": True, + "tile_size": (tile_y * sf, tile_x * sf), + "tile_overlap": (overlap * sf, overlap * sf), + "temporal_size": 0, + "temporal_overlap": 0, + } + return self.decode(z, seedvr2_tiling=seedvr2_tiling) + + def encode_tiled(self, x, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + # Temporal tiling knobs are discarded; the causal VAE owns temporal (slicing + # breaks MemoryState continuity), so temporal stays whole. + if tile_y is None: + tile_y = 512 + if tile_x is None: + tile_x = 512 + if overlap is None: + overlap_y = 64 + overlap_x = 64 + else: + overlap_y = overlap + overlap_x = overlap + overlap_y = min(overlap_y, max(0, tile_y - 8)) + overlap_x = min(overlap_x, max(0, tile_x - 8)) + self.device = x.device + return tiled_vae( + x, + self, + tile_size=(tile_y, tile_x), + tile_overlap=(overlap_y, overlap_x), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + + def comfy_format_encoded(self, samples): + if samples.ndim == 4: + samples = samples.unsqueeze(2) + samples = samples.contiguous() + samples = samples * 0.9152 + return samples + + def comfy_memory_used_decode(self, shape): + bytes_per_output_pixel = 160 + + def output_pixels(latent_t, latent_h, latent_w): + output_t = max(1, (latent_t - 1) * 4 + 1) + return output_t * latent_h * 8 * latent_w * 8 + + # SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels + # plus int64 sort indices dominate peak memory, not the VAE weight dtype. + if len(shape) == 5: + candidates = [] + if shape[1] == 16: + candidates.append((shape[2], shape[3], shape[4])) + if shape[-1] == 16: + candidates.append((shape[1], shape[2], shape[3])) + if len(candidates) == 0: + candidates.append((shape[2], shape[3], shape[4])) + pixels = max(output_pixels(*candidate) for candidate in candidates) + elif len(shape) == 4: + latent_t = max(1, (shape[1] + 15) // 16) + pixels = output_pixels(latent_t, shape[2], shape[3]) + else: + pixels = output_pixels(1, shape[-2], shape[-1]) + return pixels * bytes_per_output_pixel + + def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): + set_norm_limit(norm_max_mem) + for m in self.modules(): + if isinstance(m, InflatedCausalConv3d): + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) diff --git a/comfy/sd.py b/comfy/sd.py index a66ba1bfb..64a4a58f1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,3 +1,4 @@ +import inspect import json import torch from enum import Enum @@ -16,6 +17,7 @@ import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae +import comfy.ldm.seedvr.vae import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae @@ -467,8 +469,10 @@ class CLIP: class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - sd = diffusers_convert.convert_vae_state_dict(sd) + is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd + if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + if metadata is None or metadata.get("keep_diffusers_format") != "true": + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -540,6 +544,20 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 + self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() + self.latent_channels = 16 + self.latent_dim = 3 + self.disable_offload = True + self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape) + self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.process_input = lambda image: image * 2.0 - 1.0 + self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} @@ -1006,6 +1024,10 @@ class VAE: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) + def _decode_tiled_owned(self, samples, **kwargs): + out = self.first_stage_model.decode_tiled(samples.to(self.vae_dtype).to(self.device), **kwargs) + return self.process_output(out.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) + def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -1042,6 +1064,11 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) + def _encode_tiled_owned(self, pixel_samples, **kwargs): + x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device) + out = self.first_stage_model.encode_tiled(x, **kwargs) + return out.to(device=self.output_device, dtype=self.vae_output_dtype()) + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1089,11 +1116,19 @@ class VAE: if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - pixel_samples = self.decode_tiled_(samples_in) + if getattr(self.first_stage_model, "comfy_handles_tiling", False): + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) + else: + pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + if getattr(self.first_stage_model, "comfy_handles_tiling", False): + pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) + else: + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -1112,7 +1147,20 @@ class VAE: args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if dims == 1 or self.extra_1d_channel is not None: + if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3): + tiled_args = {} + if tile_x is not None: + tiled_args["tile_x"] = tile_x + if tile_y is not None: + tiled_args["tile_y"] = tile_y + if overlap is not None: + tiled_args["overlap"] = overlap + if tile_t is not None: + tiled_args["tile_t"] = tile_t + if overlap_t is not None: + tiled_args["overlap_t"] = overlap_t + output = self._decode_tiled_owned(samples, **tiled_args) + elif dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -1154,6 +1202,8 @@ class VAE: else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) + if isinstance(out, tuple): + out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1173,12 +1223,18 @@ class VAE: if self.latent_dim == 3: tile = 256 overlap = tile // 4 - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + if getattr(self.first_stage_model, "comfy_handles_tiling", False): + samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) + else: + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) + formatter = getattr(self.first_stage_model, "comfy_format_encoded", None) + if formatter is not None: + samples = formatter(samples) return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): @@ -1186,7 +1242,7 @@ class VAE: pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) - if dims == 3: + if dims == 3 and pixel_samples.ndim < 5: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: @@ -1210,21 +1266,39 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + if getattr(self.first_stage_model, "comfy_handles_tiling", False): + tiled_args = {} + if tile_x is not None: + tiled_args["tile_x"] = tile_x + if tile_y is not None: + tiled_args["tile_y"] = tile_y + if overlap is not None: + tiled_args["overlap"] = overlap + if tile_t is not None: + tiled_args["tile_t"] = tile_t + if overlap_t is not None: + tiled_args["overlap_t"] = overlap_t + samples = self._encode_tiled_owned(pixel_samples, **tiled_args) else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + spatial_overlap = overlap if overlap is not None else 64 + if overlap_t is None: + args["overlap"] = (1, spatial_overlap, spatial_overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + formatter = getattr(self.first_stage_model, "comfy_format_encoded", None) + if formatter is not None: + samples = formatter(samples) return samples def get_sd(self): @@ -1752,6 +1826,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) + +def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): + set_dtype = model_config.set_inference_dtype + parameters = inspect.signature(set_dtype).parameters + supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) + if supports_device: + set_dtype(dtype, manual_cast_dtype, device=device) + else: + set_dtype(dtype, manual_cast_dtype) + + def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -1859,7 +1944,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2000,7 +2085,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) if custom_operations is not None: model_config.custom_operations = custom_operations From d54ce3d7811b41a63d9d866672a73b9f34a42484 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:40:09 -0500 Subject: [PATCH 03/12] Add SeedVR2 workflow nodes --- comfy/ldm/seedvr/color_fix.py | 340 ++++++++++++ comfy_extras/nodes_seedvr.py | 997 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 3 files changed, 1338 insertions(+) create mode 100644 comfy/ldm/seedvr/color_fix.py create mode 100644 comfy_extras/nodes_seedvr.py diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py new file mode 100644 index 000000000..7ddfc03af --- /dev/null +++ b/comfy/ldm/seedvr/color_fix.py @@ -0,0 +1,340 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.vae import safe_interpolate_operation +from comfy.ldm.seedvr.constants import ( + CIELAB_DELTA, + CIELAB_KAPPA, + D65_WHITE_X, + D65_WHITE_Z, + WAVELET_DECOMP_LEVELS, +) + + +def wavelet_blur(image: Tensor, radius): + max_safe_radius = max(1, min(image.shape[-2:]) // 8) + if radius > max_safe_radius: + radius = max_safe_radius + + num_channels = image.shape[1] + + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) + + image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) + + return output + +def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): + high_freq = torch.zeros_like(image) + + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq.add_(image).sub_(low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: + + if content_feat.shape != style_feat.shape: + # Resize style to match content spatial dimensions + if len(content_feat.shape) >= 3: + # safe_interpolate_operation handles FP16 conversion automatically + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Decompose both features into frequency components + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq # Free memory immediately + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq # Free memory immediately + + if content_high_freq.shape != style_low_freq.shape: + style_low_freq = safe_interpolate_operation( + style_low_freq, + size=content_high_freq.shape[-2:], + mode='bilinear', + align_corners=False + ) + + content_high_freq.add_(style_low_freq) + + return content_high_freq.clamp_(-1.0, 1.0) + +def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: + original_shape = source.shape + + # Flatten + source_flat = source.flatten() + reference_flat = reference.flatten() + + # Sort both arrays + source_sorted, source_indices = torch.sort(source_flat) + reference_sorted, _ = torch.sort(reference_flat) + del reference_flat + + # Quantile mapping + n_source = len(source_sorted) + n_reference = len(reference_sorted) + + if n_source == n_reference: + matched_sorted = reference_sorted + else: + # Interpolate reference to match source quantiles + source_quantiles = torch.linspace(0, 1, n_source, device=device) + ref_indices = (source_quantiles * (n_reference - 1)).long() + ref_indices.clamp_(0, n_reference - 1) + matched_sorted = reference_sorted[ref_indices] + del source_quantiles, ref_indices, reference_sorted + + del source_sorted, source_flat + + # Reconstruct using argsort (portable across CUDA/ROCm/MPS) + inverse_indices = torch.argsort(source_indices) + del source_indices + matched_flat = matched_sorted[inverse_indices] + del matched_sorted, inverse_indices + + return matched_flat.reshape(original_shape) + +def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of CIELAB images to RGB color space.""" + L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] + + # LAB to XYZ + fy = (L + 16.0) / 116.0 + fx = a.div(500.0).add_(fy) + fz = fy - b / 200.0 + del L, a, b + + # XYZ transformation + x = torch.where( + fx > epsilon, + torch.pow(fx, 3.0), + fx.mul(116.0).sub_(16.0).div_(kappa) + ) + y = torch.where( + fy > epsilon, + torch.pow(fy, 3.0), + fy.mul(116.0).sub_(16.0).div_(kappa) + ) + z = torch.where( + fz > epsilon, + torch.pow(fz, 3.0), + fz.mul(116.0).sub_(16.0).div_(kappa) + ) + del fx, fy, fz + + # Apply D65 white point (in-place) + x.mul_(D65_WHITE_X) + # y *= 1.00000 # (no-op, skip) + z.mul_(D65_WHITE_Z) + + xyz = torch.stack([x, y, z], dim=1) + del x, y, z + + # Matrix multiplication: XYZ -> RGB + B, C, H, W = xyz.shape + xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) + del xyz + + # Ensure dtype consistency for matrix multiplication + xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) + rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) + del xyz_flat + + rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del rgb_linear_flat + + # Apply inverse gamma correction (delinearize) + mask = rgb_linear > 0.0031308 + rgb = torch.where( + mask, + torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), + rgb_linear * 12.92 + ) + del mask, rgb_linear + + return torch.clamp(rgb, 0.0, 1.0) + +def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" + # Apply sRGB gamma correction (linearize) + mask = rgb > 0.04045 + rgb_linear = torch.where( + mask, + torch.pow((rgb + 0.055) / 1.055, 2.4), + rgb / 12.92 + ) + del mask + + # Matrix multiplication: RGB -> XYZ + B, C, H, W = rgb_linear.shape + rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) + del rgb_linear + + # Ensure dtype consistency for matrix multiplication + rgb_flat = rgb_flat.to(dtype=matrix.dtype) + xyz_flat = torch.matmul(rgb_flat, matrix.T) + del rgb_flat + + xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del xyz_flat + + # Normalize by D65 white point (in-place) + xyz[:, 0].div_(D65_WHITE_X) # X + # xyz[:, 1] /= 1.00000 # Y (no-op, skip) + xyz[:, 2].div_(D65_WHITE_Z) # Z + + # XYZ to LAB transformation + epsilon_cubed = epsilon ** 3 + mask = xyz > epsilon_cubed + f_xyz = torch.where( + mask, + torch.pow(xyz, 1.0 / 3.0), + xyz.mul(kappa).add_(16.0).div_(116.0) + ) + del xyz, mask + + # Extract channels and compute LAB + L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] + a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] + b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] + del f_xyz + + return torch.stack([L, a, b], dim=1) + +def lab_color_transfer( + content_feat: Tensor, + style_feat: Tensor, + luminance_weight: float = 0.8 +) -> Tensor: + content_feat = wavelet_reconstruction(content_feat, style_feat) + + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + device = content_feat.device + + def ensure_float32_precision(c): + orig_dtype = c.dtype + c = c.float() + return c, orig_dtype + content_feat, original_dtype = ensure_float32_precision(content_feat) + style_feat, _ = ensure_float32_precision(style_feat) + + rgb_to_xyz_matrix = torch.tensor([ + [0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041] + ], dtype=torch.float32, device=device) + + xyz_to_rgb_matrix = torch.tensor([ + [ 3.2404542, -1.5371385, -0.4985314], + [-0.9692660, 1.8760108, 0.0415560], + [ 0.0556434, -0.2040259, 1.0572252] + ], dtype=torch.float32, device=device) + + epsilon = CIELAB_DELTA + kappa = CIELAB_KAPPA + + content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + + # Convert to LAB color space + content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del content_feat + + style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del style_feat, rgb_to_xyz_matrix + + # Match chrominance channels (a*, b*) for accurate color transfer + matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) + matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) + + # Handle luminance with weighted blending + if luminance_weight < 1.0: + # Partially match luminance for better overall color accuracy + matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) + # Blend: preserve some content L* for detail, adopt some style L* for color + result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) + del matched_L + else: + # Fully preserve content luminance + result_L = content_lab[:, 0] + + del content_lab, style_lab + + # Reconstruct LAB with corrected channels + result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) + del result_L, matched_a, matched_b + + # Convert back to RGB + result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) + del result_lab, xyz_to_rgb_matrix + + # Convert back to [-1, 1] range (in-place) + result = result_rgb.mul_(2.0).sub_(1.0) + del result_rgb + + result = result.to(original_dtype) + + return result + + +def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: + return wavelet_reconstruction(content_feat, style_feat) + + +def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False, + ) + + original_dtype = content_feat.dtype + content_feat = content_feat.float() + style_feat = style_feat.float() + + b, c = content_feat.shape[:2] + content_flat = content_feat.reshape(b, c, -1) + style_flat = style_feat.reshape(b, c, -1) + + content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) + content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) + style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + del content_flat, style_flat + + normalized = (content_feat - content_mean) / content_std + del content_mean, content_std + result = normalized * style_std + style_mean + del normalized, style_mean, style_std + + result = result.clamp_(-1.0, 1.0) + if result.dtype != original_dtype: + result = result.to(original_dtype) + return result diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py new file mode 100644 index 000000000..978de3e41 --- /dev/null +++ b/comfy_extras/nodes_seedvr.py @@ -0,0 +1,997 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import torch +import math +import logging +from einops import rearrange + +import comfy.model_management +import comfy.sample +import comfy.samplers +from comfy.ldm.seedvr.color_fix import ( + adain_color_transfer, + lab_color_transfer, + wavelet_color_transfer, +) +from comfy.ldm.seedvr.constants import ( + SEEDVR2_ADAIN_SCALE_MULTIPLIER, + SEEDVR2_CHUNK_FRAMES_PER_GB, + SEEDVR2_CHUNK_GB_MARGIN, + SEEDVR2_COLOR_MEM_HEADROOM, + SEEDVR2_COND_CHANNELS, + SEEDVR2_DTYPE_BYTES_FLOOR, + SEEDVR2_LAB_SCALE_MULTIPLIER, + SEEDVR2_LATENT_CHANNELS, + SEEDVR2_OOM_BACKOFF_DIVISOR, + SEEDVR2_WAVELET_SCALE_MULTIPLIER, +) + +from torchvision.transforms import functional as TVF +from torchvision.transforms import Lambda +from torchvision.transforms.functional import InterpolationMode + + +_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( + "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" +) + +# Private sentinel for getattr default: distinguishes "attribute missing" +# from "attribute present but None" so the failure message is accurate. +_ATTR_MISSING = object() + + +def _seedvr2_vram_seed_frames_per_chunk(free_bytes, t_pixel): + """Predict the largest 4n+1 pixel-frame chunk that fits in free_bytes.""" + free_gb = free_bytes / (1024 ** 3) + predicted = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_gb - SEEDVR2_CHUNK_GB_MARGIN) + # round (not floor) to 4n+1: the fit's central prediction lands on measured n_max + n = round((predicted - 1) / 4) + seed = 4 * int(n) + 1 + seed = max(1, min(seed, t_pixel)) + return seed + + +def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): + """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" + attempts = [frames_per_chunk] + current_chunk_latent = ( + t_latent if t_pixel <= frames_per_chunk + else (frames_per_chunk - 1) // 4 + 1 + ) + current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) + seen = {frames_per_chunk} + + for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): + chunk_latent = max(1, math.ceil(t_latent / target_chunks)) + candidate = 4 * (chunk_latent - 1) + 1 + if candidate in seen: + continue + if candidate >= attempts[-1]: + continue + attempts.append(candidate) + seen.add(candidate) + + return attempts + + +def _resolve_seedvr2_diffusion_model(model): + """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" + inner = getattr(model, "model", _ATTR_MISSING) + if inner is _ATTR_MISSING: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute " + f"(got type {type(model).__name__})." + ) + if inner is None: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None " + f"(input type {type(model).__name__})." + ) + diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING) + if diffusion_model is _ATTR_MISSING: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no " + f"'diffusion_model' attribute (got type {type(inner).__name__})." + ) + if diffusion_model is None: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' " + f"is None (model.model type {type(inner).__name__})." + ) + return diffusion_model + + +def _apply_rope_freqs_float32_cast(diffusion_model): + """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" + for module in diffusion_model.modules(): + if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): + if module.rope.freqs.data.dtype != torch.float32: + module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) + + +def get_conditions(latent, latent_blur): + t, h, w, c = latent.shape + cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 + return cond + +def div_pad(image, factor): + + height_factor, width_factor = factor + height, width = image.shape[-2:] + + pad_height = (height_factor - (height % height_factor)) % height_factor + pad_width = (width_factor - (width % width_factor)) % width_factor + + if pad_height == 0 and pad_width == 0: + return image + + if isinstance(image, torch.Tensor): + padding = (0, pad_width, 0, pad_height) + image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) + + return image + +def cut_videos(videos): + t = videos.size(1) + if t == 1: + return videos + if t <= 4 : + padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 - ((t - 1) % (4)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4) == 0 + return videos + +def _seedvr2_input_shorter_edge(images, node_name): + if images.dim() == 4: + return min(images.shape[1], images.shape[2]) + if images.dim() == 5: + return min(images.shape[2], images.shape[3]) + raise ValueError( + f"{node_name}: expected 4-D or 5-D IMAGE tensor, " + f"got shape {tuple(images.shape)}" + ) + + +def _seedvr2_pad(images, upscaled_shorter_edge, node_name): + if upscaled_shorter_edge < 2: + raise ValueError( + f"{node_name}: input shorter edge must be at least 2 pixels; " + f"got {upscaled_shorter_edge}." + ) + if images.shape[-1] > 3: + images = images[..., :3] + if images.dim() == 4: + # Comfy video components arrive as a 4-D IMAGE frame sequence: + # (frames, H, W, C). SeedVR2 consumes that as one video. + images = images.unsqueeze(0) + elif images.dim() != 5: + raise ValueError( + f"{node_name}: expected 4-D or 5-D IMAGE tensor, " + f"got shape {tuple(images.shape)}" + ) + images = images.permute(0, 1, 4, 2, 3) + + b, t, c, h, w = images.shape + images = images.reshape(b * t, c, h, w) + + clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) + images = clip(images) + images = div_pad(images, (16, 16)) + _, _, new_h, new_w = images.shape + + images = images.reshape(b, t, c, new_h, new_w) + images = cut_videos(images) + images_bthwc = rearrange(images, "b t c h w -> b t h w c") + + return io.NodeOutput(images_bthwc) + + +class SeedVR2Preprocess(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Preprocess", + display_name="Pre-Process SeedVR2 Input", + category="image/upscaling", + description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.", + search_aliases=["seedvr2", "upscale", "video upscale", "pad", "preprocess"], + inputs=[ + io.Image.Input("resized_images", tooltip="The resized image to process."), + ], + outputs=[ + io.Image.Output("images", tooltip="The padded image for VAE encoding."), + ] + ) + + @classmethod + def execute(cls, resized_images): + upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess") + return _seedvr2_pad( + resized_images, upscaled_shorter_edge, "SeedVR2Preprocess", + ) + + +class SeedVR2PostProcessing(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2PostProcessing", + display_name="Post-Process SeedVR2 Output", + category="image/upscaling", + description="Align the generated image with the original resized image and apply color correction.", + search_aliases=["seedvr2", "upscale", "color correction", "color match", "postprocess"], + inputs=[ + io.Image.Input("images", tooltip="The generated image to process."), + io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."), + io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), + ], + outputs=[io.Image.Output(display_name="images", tooltip="The aligned, color-corrected image.")], + ) + + @classmethod + def execute(cls, images, original_resized_images, color_correction_method): + alpha_input = None + if original_resized_images.shape[-1] == 4: + alpha_input = original_resized_images[..., 3:4] + original_resized_images = original_resized_images[..., :3] + decoded_5d, decoded_was_4d = cls._as_bthwc(images) + reference_full, _ = cls._as_bthwc(original_resized_images) + decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full) + + b = min(decoded_5d.shape[0], reference_full.shape[0]) + t = min(decoded_5d.shape[1], reference_full.shape[1]) + reference_h = reference_full.shape[2] + reference_w = reference_full.shape[3] + + decoded_5d = decoded_5d[:b, :t, :, :, :] + target_h = min(decoded_5d.shape[2], reference_h) + target_w = min(decoded_5d.shape[3], reference_w) + decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] + if color_correction_method in ("lab", "wavelet", "adain"): + reference_5d = reference_full[:b, :t, :, :, :] + reference_5d = cls._resize_reference(reference_5d, target_h, target_w) + output_device = decoded_5d.device + decoded_raw = cls._to_seedvr2_raw(decoded_5d) + reference_raw = cls._to_seedvr2_raw(reference_5d) + decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") + reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") + output = cls._color_transfer_chunked( + decoded_flat, reference_flat, output_device, color_correction_method, + ) + output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) + output = output.add(1.0).div(2.0).clamp(0.0, 1.0) + elif color_correction_method == "none": + output = decoded_5d + else: + raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + + if alpha_input is not None: + alpha_5d, _ = cls._as_bthwc(alpha_input) + alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :] + output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1) + h2 = output.shape[-3] - (output.shape[-3] % 2) + w2 = output.shape[-2] - (output.shape[-2] % 2) + output = output[:, :, :h2, :w2, :] + if decoded_was_4d: + output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1]) + return io.NodeOutput(output) + + @staticmethod + def _as_bthwc(images): + if images.ndim == 4: + return images.unsqueeze(0), True + if images.ndim == 5: + return images, False + raise ValueError( + f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}" + ) + + @staticmethod + def _restore_reference_batch_time(decoded, reference): + if decoded.shape[0] != 1: + return decoded + ref_b, ref_t = reference.shape[:2] + if ref_b < 1 or decoded.shape[1] % ref_b != 0: + return decoded + decoded_t = decoded.shape[1] // ref_b + if decoded_t < ref_t: + return decoded + return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4]) + + @staticmethod + def _to_seedvr2_raw(images): + return images.mul(2.0).sub(1.0) + + @staticmethod + def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): + color_device = comfy.model_management.vae_device() + decoded_flat = decoded_flat.to(device=color_device) + reference_flat = reference_flat.to(device=color_device) + output = transfer_fn(decoded_flat, reference_flat) + return output.to(device=output_device) + + @staticmethod + def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device): + color_device = comfy.model_management.vae_device() + result = None + for start in range(decoded_flat.shape[0]): + decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone() + reference_frame = reference_flat[start:start + 1].to(device=color_device).clone() + output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device) + if result is None: + result = torch.empty( + (decoded_flat.shape[0],) + tuple(output.shape[1:]), + device=output_device, + dtype=output.dtype, + ) + result[start:start + 1].copy_(output) + if result is None: + raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.") + return result + + @classmethod + def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): + chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) + while True: + next_chunk_size = None + try: + return cls._run_color_transfer_chunks( + decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, + ) + except Exception as e: + comfy.model_management.raise_non_oom(e) + if chunk_size <= 1: + raise RuntimeError( + "SeedVR2PostProcessing: color correction OOM at one frame; " + f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." + ) from e + next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) + + comfy.model_management.soft_empty_cache() + chunk_size = next_chunk_size + + @classmethod + def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): + result = None + for start in range(0, decoded_flat.shape[0], chunk_size): + end = min(start + chunk_size, decoded_flat.shape[0]) + decoded_chunk = decoded_flat[start:end] + reference_chunk = reference_flat[start:end] + if color_correction_method == "lab": + output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device) + elif color_correction_method == "wavelet": + output = cls._color_transfer_on_vae_device( + decoded_chunk, reference_chunk, output_device, wavelet_color_transfer, + ) + else: + output = cls._color_transfer_on_vae_device( + decoded_chunk, reference_chunk, output_device, adain_color_transfer, + ) + if result is None: + result = torch.empty( + (decoded_flat.shape[0],) + tuple(output.shape[1:]), + device=output_device, + dtype=output.dtype, + ) + result[start:end].copy_(output) + if result is None: + raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.") + return result + + @classmethod + def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method): + multiplier = cls._color_correction_memory_multiplier(color_correction_method) + frames = decoded_flat.shape[0] + _, channels, height, width = decoded_flat.shape + dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) + bytes_per_frame = height * width * channels * dtype_bytes * multiplier + if bytes_per_frame <= 0: + return frames + color_device = comfy.model_management.vae_device() + free_memory = comfy.model_management.get_free_memory(color_device) + chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) + return max(1, min(frames, chunk_size)) + + @staticmethod + def _color_correction_memory_multiplier(color_correction_method): + if color_correction_method == "lab": + return SEEDVR2_LAB_SCALE_MULTIPLIER + if color_correction_method == "wavelet": + return SEEDVR2_WAVELET_SCALE_MULTIPLIER + if color_correction_method == "adain": + return SEEDVR2_ADAIN_SCALE_MULTIPLIER + raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + + @staticmethod + def _resize_reference(reference, height, width): + if reference.shape[2] == height and reference.shape[3] == width: + return reference + b, t = reference.shape[:2] + reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") + resized = TVF.resize( + reference_flat, + size=(height, width), + interpolation=InterpolationMode.BICUBIC, + antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), + ) + return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) + + +class SeedVR2Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Conditioning", + display_name="Apply SeedVR2 Conditioning", + category="conditioning", + description="Build SeedVR2 positive/negative conditioning from a VAE latent.", + search_aliases=["seedvr2", "upscale", "conditioning"], + inputs=[ + io.Model.Input("model", tooltip="The SeedVR2 model."), + io.Latent.Input("vae_conditioning", display_name="latent"), + ], + outputs=[ + io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."), + io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."), + io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."), + io.Latent.Output(display_name="latent", tooltip="The latent to denoise."), + ], + ) + + @classmethod + def execute(cls, model, vae_conditioning) -> io.NodeOutput: + + vae_conditioning = vae_conditioning["samples"] + if vae_conditioning.ndim != 5: + raise ValueError( + "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " + f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." + ) + if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + raise ValueError( + "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " + f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " + f"got channel-last shape {tuple(vae_conditioning.shape)}." + ) + vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() + model_patcher = model + model = _resolve_seedvr2_diffusion_model(model_patcher) + pos_cond = model.positive_conditioning + neg_cond = model.negative_conditioning + + # Fail-loud guard against silently-wrong output when a + # DiT-only ``.safetensors`` (no ``positive_conditioning`` / + # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. + # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see + # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` + # leaves them at zero when the keys are absent. Detect that state + # here rather than at ``BaseModel.extra_conds`` (per sampling step, + # wasteful) or at the resolver helper (mixes structural shape with + # semantic content). Both buffers must be checked together — partial + # bake regressions could populate one but not the other. + if ( + pos_cond.float().abs().sum().item() == 0 + and neg_cond.float().abs().sum().item() == 0 + ): + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " + f"and negative_conditioning buffers are zero-valued — model " + f"file appears to be a DiT-only export missing " + f"the SeedVR2 conditioning tensors. " + f"Re-bake the file with ``positive_conditioning`` (58, 5120) " + f"and ``negative_conditioning`` (64, 5120) keys at top level, " + f"or load via CheckpointLoaderSimple from a bundled " + f"checkpoint." + ) + + _apply_rope_freqs_float32_cast(model) + + condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) + condition = condition.movedim(-1, 1) + latent = vae_conditioning.movedim(-1, 1) + + latent = rearrange(latent, "b c t h w -> b (c t) h w") + condition = rearrange(condition, "b c t h w -> b (c t) h w") + + negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] + positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] + + return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) + +def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, + t_end: int, channels: int) -> torch.Tensor: + """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" + B, CT, H, W = tensor_4d.shape + if CT % channels != 0: + raise ValueError( + f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " + f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." + ) + T = CT // channels + if not (0 <= t_start < t_end <= T): + raise ValueError( + f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " + f"range for T={T}." + ) + new_T = t_end - t_start + sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() + return sliced.reshape(B, channels * new_T, H, W) + + +def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): + """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" + new_list = [] + for entry in cond_list: + text_cond, options = entry[0], entry[1] + if "condition" not in options: + new_list.append(entry) + continue + new_options = options.copy() + new_options["condition"] = _slice_collapsed_4d_along_t( + new_options["condition"], t_start, t_end, + SEEDVR2_COND_CHANNELS, + ) + new_list.append([text_cond, new_options]) + return new_list + + +def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, + samples_4d: torch.Tensor, + t_start: int, + t_end: int): + """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" + if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: + return _slice_collapsed_4d_along_t( + noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, + ) + return noise_mask + + +def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: + """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" + if len(chunks_4d) == 0: + raise ValueError("_concat_chunks_along_t: empty chunk list.") + fives = [] + for ch in chunks_4d: + B, CT, H, W = ch.shape + if CT % channels != 0: + raise ValueError( + f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " + f"channel dim {CT} not divisible by channels={channels}." + ) + T = CT // channels + fives.append(ch.reshape(B, channels, T, H, W)) + cat = torch.cat(fives, dim=2).contiguous() + B, C, T_total, H, W = cat.shape + return cat.reshape(B, C * T_total, H, W) + + +def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: + """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): + Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` + (dead-band would collapse a tiny transition). Window shape matched to the reference + overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. + """ + if overlap < 1: + raise ValueError( + f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." + ) + if overlap >= 3: + t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) + blend_start = 1.0 / 3.0 + blend_end = 2.0 / 3.0 + u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) + return 0.5 + 0.5 * torch.cos(torch.pi * u) + return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) + + +def _blend_overlap_region(prev_tail_5d: torch.Tensor, + cur_head_5d: torch.Tensor) -> torch.Tensor: + """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" + if prev_tail_5d.shape != cur_head_5d.shape: + raise ValueError( + f"_blend_overlap_region: shape mismatch " + f"prev {tuple(prev_tail_5d.shape)} vs " + f"cur {tuple(cur_head_5d.shape)}." + ) + overlap = int(prev_tail_5d.shape[2]) + w_prev_1d = _hann_blend_weights_1d( + overlap, prev_tail_5d.device, prev_tail_5d.dtype, + ) + # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. + w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) + w_cur = 1.0 - w_prev + return prev_tail_5d * w_prev + cur_head_5d * w_cur + + +def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, + overlap_latent: int) -> torch.Tensor: + """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" + if len(chunk_specs) == 0: + raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") + if overlap_latent < 0: + raise ValueError( + f"_concat_chunks_with_overlap_blend: overlap_latent must be " + f">= 0; got {overlap_latent}." + ) + + # Validate channel divisibility once and capture per-chunk T. + chunk_5d = [] + for t_start, t_end, ch in chunk_specs: + B, CT, H, W = ch.shape + if CT % channels != 0: + raise ValueError( + f"_concat_chunks_with_overlap_blend: chunk shape " + f"{tuple(ch.shape)} channel dim {CT} not divisible " + f"by channels={channels}." + ) + T = CT // channels + if t_end - t_start != T: + raise ValueError( + f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " + f"declared range [{t_start}:{t_end}]." + ) + chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) + + if overlap_latent == 0: + # Fast path: pure concat in the caller-provided chunk order. + return _concat_chunks_along_t( + [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) + for _, _, c in chunk_5d], + channels, + ) + + T_total = max(t_end for _, t_end, _ in chunk_5d) + first_5d = chunk_5d[0][2] + B = first_5d.shape[0] + H = first_5d.shape[3] + W = first_5d.shape[4] + result = torch.empty( + (B, channels, T_total, H, W), + device=first_5d.device, dtype=first_5d.dtype, + ) + filled_until = 0 + for i, (cs, ce, ct_5d) in enumerate(chunk_5d): + chunk_T = int(ct_5d.shape[2]) + if i == 0: + result[:, :, cs:ce, :, :] = ct_5d + filled_until = ce + continue + # Overlap region width is bounded by both the previous fill + # frontier and the current chunk's actual length (for runt + # final chunks shorter than the configured overlap). + overlap_len = min(filled_until - cs, chunk_T) + if overlap_len > 0: + prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() + cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() + blended = _blend_overlap_region(prev_tail, cur_head) + result[:, :, cs:cs + overlap_len, :, :] = blended + tail_start = cs + overlap_len + tail_end = ce + if tail_end > tail_start: + result[:, :, tail_start:tail_end, :, :] = ( + ct_5d[:, :, overlap_len:, :, :] + ) + else: + # Disjoint chunks (overlap_latent set but this pair did not + # actually overlap, e.g. step_latent equal to chunk_latent + # in a degenerate config). Treat as concat. + result[:, :, cs:ce, :, :] = ct_5d + filled_until = ce + + return result.contiguous().reshape(B, channels * T_total, H, W) + + +def _run_standard_sample(model, seed: int, steps: int, cfg: float, + sampler_name: str, scheduler: str, + positive, negative, latent: dict, + denoise: float) -> dict: + """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" + samples_in = latent["samples"] + samples_in = comfy.sample.fix_empty_latent_channels( + model, samples_in, latent.get("downscale_ratio_spacial", None), + ) + batch_inds = latent.get("batch_index", None) + noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) + noise_mask = latent.get("noise_mask", None) + samples = comfy.sample.sample( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, samples_in, + denoise=denoise, noise_mask=noise_mask, seed=seed, + ) + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = samples + return out + + +class SeedVR2ProgressiveSampler(io.ComfyNode): + """Sequential temporal chunking sampler for SeedVR2 native. + + Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that + OOM on long sequences. The latent enters the sampler in SeedVR2's + collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` + at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that + tensor along the temporal axis, runs the configured inner sampler + sequentially per chunk against the standard ``comfy.sample.sample`` + entry point, and concatenates per-chunk outputs back into a single + ``(B, 16*T_total, H, W)`` latent. + + ``frames_per_chunk`` is expressed in pixel-frame units to match the + SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the + VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` + maps to ``(F - 1) // 4 + 1`` latent-frame chunks. + + Determinism contract: a single noise tensor is generated once from + the user seed and sliced per chunk (rather than re-seeding each + chunk), so a workflow that fits in a single chunk produces output + identical to a workflow that fits in N chunks at the same seed, + modulo the inherent T-axis chunk-boundary independence of the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2ProgressiveSampler", + display_name="Sample SeedVR2 (Progressive)", + category="sampling", + description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", + search_aliases=["seedvr2", "upscale", "video upscale", "sampler", "chunk"], + inputs=[ + io.Model.Input("model", tooltip="The model used for denoising the input latent."), + io.Int.Input("seed", default=0, min=0, + max=0xffffffffffffffff, + control_after_generate=True, + tooltip="The random seed used for creating the noise."), + io.Int.Input("steps", default=20, min=1, max=10000, + tooltip="The number of steps used in the denoising process."), + io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, + step=0.1, round=0.01, + tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), + io.Combo.Input("sampler_name", + options=comfy.samplers.SAMPLER_NAMES, + tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), + io.Combo.Input("scheduler", + options=comfy.samplers.SCHEDULER_NAMES, + tooltip="The scheduler controls how noise is gradually removed to form the image."), + io.Conditioning.Input("positive", + tooltip="The conditioning describing the attributes you want to include in the image."), + io.Conditioning.Input("negative", + tooltip="The conditioning describing the attributes you want to exclude from the image."), + io.Latent.Input("latent", + tooltip="The latent image to denoise."), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, + step=0.01, + tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), + io.Int.Input("frames_per_chunk", default=21, min=1, + max=16384, step=4, + tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), + io.Int.Input("temporal_overlap", default=0, min=0, + max=16384, + tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), + io.Combo.Input("chunking_mode", + options=["manual", "auto"], + default="manual", + tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), + ], + outputs=[io.Latent.Output(display_name="latent", tooltip="The upscaled latent.")], + ) + + @classmethod + def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent, denoise, + frames_per_chunk, temporal_overlap, + chunking_mode="manual") -> io.NodeOutput: + # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline + # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), + # imposed at ``cut_videos`` upstream and propagated through the VAE's + # temporal_downsample_factor=4. Reject violations explicitly before + # any model invocation; a silent rounding would mis-align chunk + # boundaries with the 4n+1 lattice. + if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " + f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " + f"got {frames_per_chunk}." + ) + + samples_4d = latent["samples"] + if torch.count_nonzero(samples_4d) == 0: + raise ValueError( + "SeedVR2ProgressiveSampler: input latent is empty (all zeros). " + "SeedVR2 is an upscaler; connect an encoded latent from " + "'Apply SeedVR2 conditioning' rather than an empty latent." + ) + samples_4d = comfy.sample.fix_empty_latent_channels( + model, samples_4d, + latent.get("downscale_ratio_spacial", None), + ) + if samples_4d.ndim != 4: + raise ValueError( + f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " + f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." + ) + B, CT, H, W = samples_4d.shape + if CT % SEEDVR2_LATENT_CHANNELS != 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " + f"not divisible by SeedVR2 latent channels " + f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " + f"SeedVR2-shaped." + ) + T_latent = CT // SEEDVR2_LATENT_CHANNELS + T_pixel = 4 * (T_latent - 1) + 1 + + if chunking_mode not in ("manual", "auto"): + raise ValueError( + f"SeedVR2ProgressiveSampler: chunking_mode must be " + f"'manual' or 'auto'; got {chunking_mode!r}." + ) + + if chunking_mode == "auto": + free_memory = comfy.model_management.get_free_memory(model.load_device) + seed_frames_per_chunk = _seedvr2_vram_seed_frames_per_chunk( + free_memory, T_pixel, + ) + logging.info( + "SeedVR2ProgressiveSampler auto: free=%.2fGB -> seeding " + "frames_per_chunk=%s (4n+1; T_pixel=%s).", + free_memory / (1024 ** 3), seed_frames_per_chunk, T_pixel, + ) + attempts = _seedvr2_auto_chunk_attempts( + T_latent, T_pixel, seed_frames_per_chunk, + ) + for i, attempt_frames_per_chunk in enumerate(attempts): + retry = False + try: + return cls.execute( + model=model, seed=seed, steps=steps, cfg=cfg, + sampler_name=sampler_name, scheduler=scheduler, + positive=positive, negative=negative, + latent=latent, denoise=denoise, + frames_per_chunk=attempt_frames_per_chunk, + temporal_overlap=temporal_overlap, + chunking_mode="manual", + ) + except Exception as e: + comfy.model_management.raise_non_oom(e) + if i == len(attempts) - 1: + raise RuntimeError( + "SeedVR2ProgressiveSampler: exhausted auto " + "chunking attempts after OOM. Tried " + f"frames_per_chunk values {attempts}." + ) from e + retry = True + + if retry: + logging.warning( + "SeedVR2ProgressiveSampler auto chunking OOM at " + "frames_per_chunk=%s; retrying with " + "frames_per_chunk=%s.", + attempt_frames_per_chunk, attempts[i + 1], + ) + comfy.model_management.soft_empty_cache() + + # Short-circuit: total fits in one chunk -> standard path with no + # chunking overhead. Output of this branch is byte-identical to the + # built-in KSampler given the same (model, seed, steps, cfg, + # sampler_name, scheduler, positive, negative, latent, + # denoise) tuple. + if T_pixel <= frames_per_chunk: + return io.NodeOutput(_run_standard_sample( + model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent, denoise, + )) + + # Map pixel chunk -> latent chunk. Each chunk's latent length is + # at most ``chunk_latent``; the final chunk may be a runt that + # is automatically 4n+1-aligned in the pixel domain by the + # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer + # T_latent corresponds to a valid 4n+1 pixel count). + chunk_latent = (frames_per_chunk - 1) // 4 + 1 + + # ``temporal_overlap`` is exposed in latent-frame units, but users + # do not know the derived latent chunk length. Treat oversized + # values as "maximum valid overlap" while preserving a strictly + # positive chunk-loop stride. + if temporal_overlap < 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " + f"got {temporal_overlap}." + ) + temporal_overlap = min(temporal_overlap, chunk_latent - 1) + step_latent = chunk_latent - temporal_overlap + + # Generate full noise once from the user seed, then slice along T + # per chunk. Using one global noise tensor (rather than re-seeding + # per chunk) preserves seed-determinism across chunk-count + # variations: the same (seed, total T_latent) always produces the + # same noise samples regardless of how the work is partitioned. + batch_inds = latent.get("batch_index", None) + noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) + + noise_mask = latent.get("noise_mask", None) + + # Build the flat list of chunk ranges first so the chunking + # geometry is fully known before any sample call. + chunk_ranges = [] + for chunk_start in range(0, T_latent, step_latent): + chunk_end = min(chunk_start + chunk_latent, T_latent) + if chunk_start >= chunk_end: + # The final iteration of a stride that lands exactly on + # T_latent produces a zero-length chunk; skip it. + break + chunk_ranges.append((chunk_start, chunk_end)) + if chunk_end >= T_latent: + break + + def _sample_one_chunk(chunk_start, chunk_end): + samples_chunk = _slice_collapsed_4d_along_t( + samples_4d, chunk_start, chunk_end, + SEEDVR2_LATENT_CHANNELS, + ) + noise_chunk = _slice_collapsed_4d_along_t( + noise_full, chunk_start, chunk_end, + SEEDVR2_LATENT_CHANNELS, + ) + positive_chunk = _slice_seedvr2_cond_along_t( + positive, chunk_start, chunk_end, + ) + negative_chunk = _slice_seedvr2_cond_along_t( + negative, chunk_start, chunk_end, + ) + + # Per-chunk noise_mask handling: standard masks are passed + # through for KSampler expansion; pre-expanded collapsed + # masks are sliced. + chunk_noise_mask = None + if noise_mask is not None: + chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( + noise_mask, samples_4d, chunk_start, chunk_end, + ) + + return comfy.sample.sample( + model, noise_chunk, steps, cfg, sampler_name, scheduler, + positive_chunk, negative_chunk, samples_chunk, + denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, + ) + + chunk_specs = [] + for chunk_start, chunk_end in chunk_ranges: + chunk_samples = _sample_one_chunk(chunk_start, chunk_end) + chunk_specs.append((chunk_start, chunk_end, chunk_samples)) + + final = _concat_chunks_with_overlap_blend( + chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, + ) + + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = final + return io.NodeOutput(out) + + +class SeedVRExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SeedVR2Conditioning, + SeedVR2Preprocess, + SeedVR2PostProcessing, + SeedVR2ProgressiveSampler, + ] + +async def comfy_entrypoint() -> SeedVRExtension: + return SeedVRExtension() diff --git a/nodes.py b/nodes.py index 0d422d418..136071de2 100644 --- a/nodes.py +++ b/nodes.py @@ -2419,6 +2419,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_seedvr.py", "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", From 0fdbc5d2604258eedeffd7ef797c8c5c91da9f0f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:40:28 -0500 Subject: [PATCH 04/12] Add SeedVR2 core coverage --- tests-unit/comfy_test/model_detection_test.py | 57 ++++ tests-unit/comfy_test/test_seedvr2_dtype.py | 49 +++ .../comfy_test/test_seedvr2_internals.py | 216 ++++++++++++ tests-unit/comfy_test/test_seedvr2_model.py | 307 ++++++++++++++++++ 4 files changed, 629 insertions(+) create mode 100644 tests-unit/comfy_test/test_seedvr2_dtype.py create mode 100644 tests-unit/comfy_test/test_seedvr2_internals.py create mode 100644 tests-unit/comfy_test/test_seedvr2_model.py diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 4e9350602..109e2b13b 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd(): return sd +def _make_seedvr2_7b_separate_mm_sd(): + return { + "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), + } + + +def _make_seedvr2_7b_shared_mm_sd(): + return { + "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + +def _make_seedvr2_3b_shared_mm_sd(): + return { + "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -125,6 +143,45 @@ class TestModelDetection: assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" + def test_seedvr2_7b_separate_mm_detection_config(self): + sd = _make_seedvr2_7b_separate_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 36 + assert unet_config["mlp_type"] == "normal" + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_7b_shared_mm_detection_config(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 10 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_3b_shared_mm_detection_config(self): + sd = _make_seedvr2_3b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 2560 + assert unet_config["heads"] == 20 + assert unet_config["num_layers"] == 32 + assert unet_config["mlp_type"] == "swiglu" + def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py new file mode 100644 index 000000000..f03c0406c --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_dtype.py @@ -0,0 +1,49 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sd +import comfy.supported_models +import comfy.ldm.seedvr.model as seedvr_model +import comfy.ldm.seedvr.vae as seedvr_vae + + +def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): + bf16_device = object() + fp16_device = object() + + monkeypatch.setattr( + comfy.supported_models.comfy.model_management, + "should_use_bf16", + lambda device=None: device is bf16_device, + ) + + bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) + assert bf16_config.manual_cast_dtype is torch.bfloat16 + + fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) + assert fp16_config.manual_cast_dtype is None + + +def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): + context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) + + txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) + + torch.testing.assert_close(txt, context.squeeze(0)) + torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) + + +def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): + wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper) + estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160)) + old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 + + assert estimate == 101 * 960 * 1280 * 160 + assert estimate > 15 * 1024 ** 3 + assert estimate > old_estimate * 100 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py new file mode 100644 index 000000000..dd3121428 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -0,0 +1,216 @@ +"""Consolidated SeedVR2 internals regression tests. + +Sources (all merged verbatim, helper names disambiguated where colliding): + + * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare + memory_occupy against get_norm_limit(), not float('inf'). + * SeedVR2 variable-length attention split-loop contract. + +Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and +comfy.ldm.modules.attention transitively pull in comfy.model_management, +which probes torch.cuda.current_device() at import time unless args.cpu is +set first. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +import comfy.ldm.modules.attention as attention # noqa: E402 +import comfy.ops as comfy_ops # noqa: E402 +from comfy.ldm.seedvr.vae import ( # noqa: E402 + causal_norm_wrapper, + set_norm_limit, +) +from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402 + + +# --------------------------------------------------------------------------- +# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) +# --------------------------------------------------------------------------- + +_NUM_CHANNELS = 8 +_NUM_GROUPS = 4 +_TENSOR_SHAPE = (1, 8, 2, 4, 4) + +_GROUPNORM_SUBCLASSES = [ + pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"), + pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"), +] + + +@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) +def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): + real_group_norm = vae_mod.F.group_norm + set_norm_limit(1e-9) + try: + gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS) + gn.eval() + + forward_hook_calls = [] + + def _hook(module, inputs, output): + forward_hook_calls.append(tuple(inputs[0].shape)) + + spy_calls = [] + + def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): + spy_calls.append({"num_groups": int(num_groups_arg)}) + return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) + + handle = gn.register_forward_hook(_hook) + try: + with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): + out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) + finally: + handle.remove() + + full_calls = len(forward_hook_calls) + chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS) + + assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE + assert full_calls == 0, ( + f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}" + ) + assert chunked_calls > 0, ( + f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}" + ) + finally: + set_norm_limit(None) + + +# --------------------------------------------------------------------------- +# SeedVR2 var_attention split-loop tests +# --------------------------------------------------------------------------- + +def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): + dim = 8 + heads = 2 + head_dim = 4 + attn = seedvr_model.NaSwinAttention( + vid_dim=dim, + txt_dim=dim, + heads=heads, + head_dim=head_dim, + qk_bias=False, + qk_norm=seedvr_model.CustomRMSNorm, + qk_norm_eps=1e-6, + rope_type=None, + rope_dim=head_dim, + shared_weights=False, + window=(2, 1, 1), + window_method="720pwin_by_size_bysize", + version=True, + device="cpu", + dtype=torch.float32, + operations=comfy_ops.disable_weight_init, + ) + generator = torch.Generator(device="cpu").manual_seed(11) + vid = torch.randn(8, dim, generator=generator) + txt = torch.randn(3, dim, generator=generator) + vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long) + txt_shape = torch.tensor([[3]], dtype=torch.long) + calls = [] + + def fake_optimized_var_attention(**kwargs): + calls.append(kwargs) + return kwargs["q"] + + monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention) + + vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True)) + + assert tuple(vid_out.shape) == (8, dim) + assert tuple(txt_out.shape) == (3, dim) + assert len(calls) == 1 + call = calls[0] + assert tuple(call["q"].shape) == (14, heads, head_dim) + assert tuple(call["k"].shape) == (14, heads, head_dim) + assert tuple(call["v"].shape) == (14, heads, head_dim) + assert call["heads"] == heads + assert call["skip_reshape"] is True + assert call["skip_output_reshape"] is True + torch.testing.assert_close( + call["cu_seqlens_q"], + torch.tensor([0, 7, 14], dtype=torch.int32), + rtol=0, + atol=0, + ) + torch.testing.assert_close( + call["cu_seqlens_k"], + torch.tensor([0, 7, 14], dtype=torch.int32), + rtol=0, + atol=0, + ) + + +def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): + heads = 2 + head_dim = 3 + q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) + k = q + 100 + v = q + 200 + cu = torch.tensor([0, 2, 5], dtype=torch.int32) + calls = [] + + def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): + calls.append( + { + "q_shape": tuple(q_arg.shape), + "k_shape": tuple(k_arg.shape), + "v_shape": tuple(v_arg.shape), + "heads": heads_arg, + "kwargs": kwargs, + } + ) + return q_arg + v_arg + + monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) + + out = var_attention_optimized_split( + q, + k, + v, + heads, + cu, + cu, + skip_reshape=True, + skip_output_reshape=True, + ) + + assert tuple(out.shape) == (5, heads, head_dim) + assert len(calls) == 2 + assert calls[0]["q_shape"] == (1, heads, 2, head_dim) + assert calls[1]["q_shape"] == (1, heads, 3, head_dim) + assert all(call["heads"] == heads for call in calls) + assert all(call["kwargs"]["skip_reshape"] is True for call in calls) + assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) + torch.testing.assert_close(out, q + v, rtol=0, atol=0) + + +def test_var_attention_optimized_split_rejects_bad_offsets(): + q = torch.randn(5, 2, 3) + cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) + cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) + + with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): + var_attention_optimized_split( + q, + q, + q, + 2, + cu_bad, + cu_ok, + skip_reshape=True, + skip_output_reshape=True, + ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py new file mode 100644 index 000000000..feae2211f --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -0,0 +1,307 @@ +"""Consolidated SeedVR2 model/graph/forward regression tests. + +Merged from: +- seedvr_model_test.py +- test_seedvr_7b_final_block_text_path.py +- test_seedvr_forward_no_device_cast.py +- test_seedvr_latent_format.py +- test_seedvr2_vae_graph_boundaries.py +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import torch +from torch import nn + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy # noqa: E402 +import comfy.latent_formats # noqa: E402 +import comfy.ldm.seedvr.model # noqa: E402 +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.model_management # noqa: E402 +import comfy.sample # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +import nodes as nodes_mod # noqa: E402 +from comfy.ldm.seedvr.model import NaDiT # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers from seedvr_model_test.py +# --------------------------------------------------------------------------- + + +def _make_standin(positive_conditioning): + class _StandIn(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "positive_conditioning", positive_conditioning + ) + + _resolve_text_conditioning = NaDiT._resolve_text_conditioning + + return _StandIn() + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr_7b_final_block_text_path.py +# --------------------------------------------------------------------------- + + +class _StubModule(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + +def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: + flags = [] + + class _Block(_StubModule): + def __init__(self, *args, **kwargs): + flags.append(kwargs["is_last_layer"]) + super().__init__() + + monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) + monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) + monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) + monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) + + seedvr_model.NaDiT( + norm_eps=1e-5, + num_layers=4, + mlp_type="normal", + vid_dim=vid_dim, + txt_in_dim=txt_in_dim, + heads=24, + mm_layers=3, + ) + + return flags + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr_latent_format.py +# --------------------------------------------------------------------------- + + +class _Model: + def __init__(self, latent_format): + self._latent_format = latent_format + + def get_model_object(self, name): + assert name == "latent_format" + return self._latent_format + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr2_vae_graph_boundaries.py +# --------------------------------------------------------------------------- + + +class _Patcher: + def get_free_memory(self, device): + return 1024 * 1024 * 1024 + + +class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self, encoded): + nn.Module.__init__(self) + self.encoded = encoded + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.seen = [] + + def encode(self, x): + self.seen.append(tuple(x.shape)) + return self.encoded.to(device=x.device, dtype=x.dtype) + + +class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.calls = [] + + def decode(self, z, seedvr2_tiling=None): + self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) + if z.ndim == 4: + b, tc, h, w = z.shape + t = tc // 16 + else: + b, _, t, h, w = z.shape + return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) + + +def _make_vae(wrapper): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = wrapper + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) + vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + vae.output_channels = 3 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.crop_input = False + vae.not_video = False + vae.patcher = _Patcher() + vae.process_input = lambda image: image + vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) + vae.vae_output_dtype = lambda: torch.float32 + vae.memory_used_encode = lambda shape, dtype: 1 + vae.memory_used_decode = lambda shape, dtype: 1 + vae.throw_exception_if_invalid = lambda: None + vae.vae_encode_crop_pixels = lambda pixels: pixels + vae.spacial_compression_decode = lambda: 8 + vae.temporal_compression_decode = lambda: 4 + return vae + + +# --------------------------------------------------------------------------- +# Tests from seedvr_model_test.py +# --------------------------------------------------------------------------- + + +def test_missing_context_falls_back_to_positive_buffer(): + """``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion.""" + pos_buffer = torch.full((58, 5120), 7.0) + standin = _make_standin(pos_buffer) + txt, txt_shape = standin._resolve_text_conditioning(None) + assert txt.shape == (58, 5120) + assert (txt == 7.0).all(), ( + "fallback path must use the positive_conditioning buffer " + "verbatim, not a zero tensor" + ) + assert txt_shape.shape == (1, 1) + assert txt_shape[0, 0].item() == 58 + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr_7b_final_block_text_path.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): + assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ + False, + False, + False, + False, + ] + + +def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): + rope = seedvr_model.get_na_rope("rope3d", dim=64) + generator = torch.Generator(device="cpu").manual_seed(0) + q = torch.randn(4, 2, 128, generator=generator) + k = torch.randn(4, 2, 128, generator=generator) + shape = torch.tensor([[1, 2, 2]], dtype=torch.long) + freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) + + expected_q = seedvr_model._apply_seedvr2_rotary_emb( + freqs, + q.permute(1, 0, 2).float(), + ).to(q.dtype).permute(1, 0, 2) + expected_k = seedvr_model._apply_seedvr2_rotary_emb( + freqs, + k.permute(1, 0, 2).float(), + ).to(k.dtype).permute(1, 0, 2) + + actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) + + torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) + torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr_latent_format.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): + latent_format = comfy.latent_formats.SeedVR2() + latent_image = torch.zeros(1, 1, 4, 5) + + fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) + + assert latent_format.latent_channels == 16 + assert latent_format.latent_dimensions == 2 + assert fixed.shape == (1, 16, 4, 5) + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr2_vae_graph_boundaries.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + + encoded = torch.full((1, 16, 2, 4, 5), 2.0) + vae = _make_vae(_EncodeWrapper(encoded)) + pixels = torch.zeros(1, 5, 32, 40, 3) + + node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] + node_latent = node_output["samples"] + assert set(node_output) == {"samples"} + assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) + assert node_latent.dtype == torch.float32 + assert node_latent.stride()[-1] == 1 + assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) + + tiled = torch.full((1, 16, 2, 4, 5), 3.0) + monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) + tiled_output = nodes_mod.VAEEncodeTiled().encode( + vae, + pixels, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + )[0] + tiled_latent = tiled_output["samples"] + assert set(tiled_output) == {"samples"} + assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) + assert tiled_latent.dtype == torch.float32 + assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) + + +def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + nodes_mod.VAEDecodeTiled().decode( + vae, + {"samples": torch.zeros(1, 16, 2, 4, 5)}, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + ) + + # Spatial inputs flow through; temporal inputs are discarded — SeedVR2 owns + # temporal via the MemoryState causal cache, so VAEDecodeTiled's temporal + # knobs are no-ops at the wrapper. + assert vae.first_stage_model.calls == [ + { + "shape": (1, 16, 2, 4, 5), + "seedvr2_tiling": { + "enable_tiling": True, + "tile_size": (512, 512), + "tile_overlap": (64, 64), + "temporal_size": 0, + "temporal_overlap": 0, + }, + } + ] From bed0cd2b8c02c686ecd9bab04c12fe6aace2ba0f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:40:49 -0500 Subject: [PATCH 05/12] Add SeedVR2 VAE coverage --- .../comfy_test/seedvr_vae_forward_test.py | 86 +++++ .../comfy_test/test_seedvr2_vae_decode.py | 91 +++++ .../comfy_test/test_seedvr2_vae_tiled.py | 348 ++++++++++++++++++ 3 files changed, 525 insertions(+) create mode 100644 tests-unit/comfy_test/seedvr_vae_forward_test.py create mode 100644 tests-unit/comfy_test/test_seedvr2_vae_decode.py create mode 100644 tests-unit/comfy_test/test_seedvr2_vae_tiled.py diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py new file mode 100644 index 000000000..d4af4c2b1 --- /dev/null +++ b/tests-unit/comfy_test/seedvr_vae_forward_test.py @@ -0,0 +1,86 @@ +"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must +honor the actual tensor/tuple return contract of ``encode()`` and +``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` +or ``.sample`` attributes on those returns. + +The pre-fix body raised ``AttributeError: 'Tensor' object has no +attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and +``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` +for ``mode == "decode"`` (the class only defines ``decode_`` with a +trailing underscore). The post-fix body unwraps the optional one-element +tuple shape that ``return_dict=False`` produces and returns the tensor +directly. + +Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses +the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and +overrides ``encode``/``decode_`` with known tensors so the contract can +be probed without loading any real VAE weights. +""" + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 + + +_LATENT_SHAPE = (1, 16, 2, 2, 2) +_DECODED_SHAPE = (1, 3, 5, 16, 16) +_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) +_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) + + +class _StubVAE(VideoAutoencoderKL): + def __init__(self): + nn.Module.__init__(self) + self._encode_out = torch.zeros(*_LATENT_SHAPE) + self._decode_out = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return self._encode_out + + def decode_(self, z, return_dict=True): + return self._decode_out + + +def test_forward_encode_returns_tensor(): + vae = _StubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="encode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_LATENT_SHAPE) + + +def test_forward_decode_returns_tensor(): + vae = _StubVAE() + z = torch.zeros(*_INPUT_DECODE_SHAPE) + result = vae.forward(z, mode="decode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) + + +class _TupleReturningStubVAE(VideoAutoencoderKL): + """Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``.""" + + def __init__(self): + nn.Module.__init__(self) + self._encode_tensor = torch.zeros(*_LATENT_SHAPE) + self._decode_tensor = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return (self._encode_tensor,) + + def decode_(self, z, return_dict=True): + return (self._decode_tensor,) + + +def test_forward_all_unwraps_one_tuple_at_each_step(): + vae = _TupleReturningStubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="all") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py new file mode 100644 index 000000000..ea9f978f3 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_decode.py @@ -0,0 +1,91 @@ +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + return wrapper + + +def _fingerprint_decode_(self, z, return_dict=True): + b = int(z.shape[0]) + t = int(z.shape[2]) + h = int(z.shape[3]) + w = int(z.shape[4]) + out = torch.empty(b, 3, t, h * 8, w * 8) + for batch_idx in range(b): + out[batch_idx].fill_(float(batch_idx + 1)) + return out + + +def _decode_with_patches(wrapper, z): + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): + return wrapper.decode(z) + + +def test_decode_b2_t3_multi_frame_batch_unchanged(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) + + assert tuple(out.shape) == (2, 3, 3, 16, 16) + + +class _Wrapper(vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.calls = [] + + def parameters(self): + return iter([torch.nn.Parameter(torch.zeros(()))]) + +def _decode_stub(self, latent): + self.calls.append(tuple(latent.shape)) + return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) + + +def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): + wrapper = _Wrapper() + + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): + out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) + + assert tuple(out.shape) == (1, 3, 2, 32, 40) + assert wrapper.calls == [(1, 16, 2, 4, 5)] + + +def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): + wrapper = _Wrapper() + + with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): + wrapper.decode(torch.zeros(1, 16, 4)) + + +def _t_padded(t_in: int) -> int: + if t_in == 1: + return 1 + if t_in <= 4: + return 5 + if (t_in - 1) % 4 == 0: + return t_in + return t_in + (4 - ((t_in - 1) % 4)) + + +@pytest.mark.parametrize("t_in", [1, 5, 9]) +def test_t_padded_matches_cut_videos(t_in): + dummy = torch.zeros(1, t_in, 1, 1, 1) + assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py new file mode 100644 index 000000000..ced2fe34f --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -0,0 +1,348 @@ +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_decode_latent_min_size_override.py +# --------------------------------------------------------------------------- + + +def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): + from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae + + class StubVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_latent_min_size = 2 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, t_chunk): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return VideoAutoencoderKL.slicing_decode(self, t_chunk) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + b, c, d, h, w = z.shape + return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) + + vae = StubVAEModel() + z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=False, + ) + + assert vae.decode_min_sizes == [5] + assert vae.memory_states == [MemoryState.DISABLED] + assert vae.slicing_latent_min_size == 2 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_encode_runt_slice_override.py +# --------------------------------------------------------------------------- + + +def test_zero_temporal_size_preserves_min_size_when_encode_raises(): + from comfy.ldm.seedvr.vae import tiled_vae + + class RaisingVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def encode(self, t_chunk): + raise RuntimeError("simulated encode failure") + + vae = RaisingVAEModel() + x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) + + raised = False + try: + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + except RuntimeError as exc: + if "simulated encode failure" not in str(exc): + raise + raised = True + + assert raised + assert vae.slicing_sample_min_size == 4 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_temporal_slicing.py +# --------------------------------------------------------------------------- + + +class _SlicingDecodeVAE(nn.Module): + def __init__(self, slicing_latent_min_size): + super().__init__() + self.slicing_latent_min_size = slicing_latent_min_size + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, z): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + x = z[:, :1].repeat( + 1, + 3, + 1, + self.spatial_downsample_factor, + self.spatial_downsample_factor, + ) + return x + + +def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): + vae = _SlicingDecodeVAE(slicing_latent_min_size=2) + z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=12, + temporal_overlap=4, + encode=False, + ) + + assert vae.decode_min_sizes == [2] + assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] + assert vae.slicing_latent_min_size == 2 + + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + seedvr2_tiling = { + "enable_tiling": True, + "tile_size": (64, 64), + "tile_overlap": (0, 0), + "temporal_size": 8, + "temporal_overlap": 7, + } + + captured = {} + + def _fake_tiled_vae(latent, model, **kwargs): + captured.update(kwargs) + return torch.zeros(1, 3, 1, 16, 16) + + with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): + wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) + + assert captured["temporal_overlap"] == 7 + + +# --------------------------------------------------------------------------- +# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py +# --------------------------------------------------------------------------- + + +def _force_oom(*a, **k): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def _make_vae(first_stage_model, latent_channels, latent_dim): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = first_stage_model + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = vae.downscale_ratio = 8 + vae.upscale_index_formula = vae.downscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = latent_channels + vae.latent_dim = latent_dim + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_decode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_decode = lambda *a, **k: 1 + return vae + + +def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): + mm = sd_mod.model_management + with ExitStack() as stack: + stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) + stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) + stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) + stack.enter_context(patch.object(sd_mod.VAE, "_decode_tiled_owned", seedvr2_call)) + stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) + if patch_wrapper_decode: + stack.enter_context(patch.object( + seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", + side_effect=_force_oom)) + vae.decode(samples) + + +def test_4d_seedvr2_latent_routes_to_owned_decode_tiled(): + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) + assert seedvr2_call.call_count == 1 + assert generic_call.call_count == 0 + + +def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): + first_stage = MagicMock() + first_stage.comfy_handles_tiling = False + first_stage.decode = MagicMock(side_effect=_force_oom) + vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) + assert generic_call.call_count == 1 + assert seedvr2_call.call_count == 0 + + +# --------------------------------------------------------------------------- +# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py +# --------------------------------------------------------------------------- + + +def _populate_common_vae_attrs_fallback(vae): + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = 8 + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = 8 + vae.downscale_index_formula = None + vae.not_video = False + vae.crop_input = False + vae.pad_channel_value = None + + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_encode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_encode = lambda *a, **k: 1 + + +def _make_seedvr2_vae_fallback(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper + ) + vae.first_stage_model = wrapper + _populate_common_vae_attrs_fallback(vae) + return vae + + +def _make_non_seedvr2_vae_fallback(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = MagicMock() + vae.first_stage_model.comfy_handles_tiling = False + _populate_common_vae_attrs_fallback(vae) + return vae + + +def _force_regular_encode_oom(*args, **kwargs): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom(): + vae = _make_seedvr2_vae_fallback() + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", + side_effect=_force_regular_encode_oom), \ + patch.object(sd_mod.VAE, "_encode_tiled_owned", seedvr2_call), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected _encode_tiled_owned to be called once for a SeedVR2 3D " + f"input under OOM fallback; got {seedvr2_call.call_count} calls." + ) + assert generic_call.call_count == 0, ( + f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " + f"{generic_call.call_count} calls." + ) + + +def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): + vae = _make_non_seedvr2_vae_fallback() + vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) + vae.upscale_ratio = (lambda a: a * 4, 8, 8) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + with patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode_tiled(pixel_samples) + + assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64) From 7050bdc02be0c76cebd84ecdc4e2047efeefae2c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:41:05 -0500 Subject: [PATCH 06/12] Add SeedVR2 node coverage --- .../test_seedvr2_conditioning.py | 213 ++++++++++++++++++ .../comfy_extras_test/test_seedvr2_nodes.py | 55 +++++ .../test_seedvr2_post_processing.py | 57 +++++ 3 files changed, 325 insertions(+) create mode 100644 tests-unit/comfy_extras_test/test_seedvr2_conditioning.py create mode 100644 tests-unit/comfy_extras_test/test_seedvr2_nodes.py create mode 100644 tests-unit/comfy_extras_test/test_seedvr2_post_processing.py diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py new file mode 100644 index 000000000..2a6e3d430 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -0,0 +1,213 @@ +"""Consolidated SeedVR2 conditioning and refactor regression tests. + +Merges the prior test_seedvr2_refactor_nodes.py and +test_seedvr_conditioning_hardening.py modules. Refactor tests use the +top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests +use _import_nodes_seedvr_isolated() for sys.modules isolation when +mocking comfy.model_management. +""" + +import importlib +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +_SENTINEL = object() +_TARGETS = ( + ("comfy.model_management", "comfy"), + ("comfy_extras.nodes_seedvr", "comfy_extras"), +) + + +def _import_nodes_seedvr_isolated(): + """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" + priors = [] + for mod_name, parent_name in _TARGETS: + prior_mod = sys.modules.get(mod_name, _SENTINEL) + parent = sys.modules.get(parent_name) + attr = mod_name.split(".")[-1] + prior_attr = ( + getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL + ) + priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) + + mock_mm = MagicMock() + for fn in ( + "xformers_enabled", "xformers_enabled_vae", + "pytorch_attention_enabled", "pytorch_attention_enabled_vae", + "sage_attention_enabled", "flash_attention_enabled", + "is_intel_xpu", + ): + getattr(mock_mm, fn).return_value = False + tv = torch.version.__version__.split(".") + mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) + mock_mm.WINDOWS = False + sys.modules["comfy.model_management"] = mock_mm + if sys.modules.get("comfy") is None: + import comfy as _comfy_pkg # noqa: F401 + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( + importlib.import_module("comfy_extras.nodes_seedvr") + ) + + def _restore(): + for mod_name, parent_name, attr, prior_mod, prior_attr in priors: + if prior_mod is _SENTINEL: + sys.modules.pop(mod_name, None) + else: + sys.modules[mod_name] = prior_mod + parent = sys.modules.get(parent_name) + if parent is None: + continue + if prior_attr is _SENTINEL: + if hasattr(parent, attr): + delattr(parent, attr) + else: + setattr(parent, attr, prior_attr) + + return nodes_seedvr, _restore + + +class _Rope(nn.Module): + """Minimal RoPE stub exposing a `freqs` parameter.""" + def __init__(self): + super().__init__() + self.freqs = nn.Parameter(torch.zeros(4)) + + +class _Block(nn.Module): + """Minimal transformer block stub holding a `_Rope`.""" + def __init__(self): + super().__init__() + self.rope = _Rope() + + +class _DiffusionModel(nn.Module): + """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" + def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): + super().__init__() + self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) + pos = torch.zeros if zero_conditioning else torch.ones + self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) + self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) + + +class _ModelInner: + """Inner model wrapper exposing `.diffusion_model`.""" + def __init__(self, diffusion_model): + self.diffusion_model = diffusion_model + + +class _ModelPatcher: + """ModelPatcher stub exposing `.model._ModelInner`.""" + def __init__(self, diffusion_model): + self.model = _ModelInner(diffusion_model) + + +def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + schema = nodes_seedvr.SeedVR2Conditioning.define_schema() + assert [input_item.id for input_item in schema.inputs] == [ + "model", + "vae_conditioning", + ] + assert schema.inputs[1].display_name == "latent" + assert [output.display_name for output in schema.outputs] == [ + "model", + "positive", + "negative", + "latent", + ] + finally: + restore() + + +def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + patcher = _ModelPatcher(diffusion_model) + samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) + vae_conditioning = {"samples": samples} + + _, first_positive, first_negative, first_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + _, second_positive, second_negative, second_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + + expected_latent = samples.reshape(1, 6, 2, 2) + channel_last = samples.movedim(1, -1).contiguous() + expected_condition = torch.cat( + [ + channel_last, + torch.ones((*channel_last.shape[:-1], 1)), + ], + dim=-1, + ).movedim(-1, 1).reshape(1, 9, 2, 2) + + assert torch.equal(first_latent["samples"], expected_latent) + assert torch.equal(second_latent["samples"], expected_latent) + assert torch.equal( + first_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + first_negative[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_negative[0][1]["condition"], + expected_condition, + ) + finally: + restore() + + +def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert message.startswith( + nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX + ), ( + "Fail-loud message must use the standard " + "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " + f"can match it. Got: {message!r}" + ) + assert "positive_conditioning" in message + assert "negative_conditioning" in message + finally: + restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py new file mode 100644 index 000000000..f7d9a4f65 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py @@ -0,0 +1,55 @@ +import importlib +import inspect +import sys +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +def test_seedvr_node_signature_matches_schema(): + mock_mm = MagicMock() + mock_mm.xformers_enabled.return_value = False + mock_mm.xformers_enabled_vae.return_value = False + mock_mm.sage_attention_enabled.return_value = False + mock_mm.flash_attention_enabled.return_value = False + + sentinel = object() + prior_cpu = cli_args.cpu + cli_args.cpu = True + prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) + comfy_pkg = sys.modules.get("comfy") + prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel + + with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + sys.modules.pop("comfy_extras.nodes_seedvr", None) + try: + nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") + for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): + schema_ids = [i.id for i in node_cls.define_schema().inputs] + exec_params = [ + p for p in inspect.signature(node_cls.execute).parameters.keys() + if p != "cls" + ] + assert schema_ids == exec_params, ( + f"{node_cls.__name__} schema/execute drift: " + f"schema_ids={schema_ids}, exec_params={exec_params}" + ) + finally: + cli_args.cpu = prior_cpu + if prior_module is sentinel: + sys.modules.pop("comfy_extras.nodes_seedvr", None) + else: + sys.modules["comfy_extras.nodes_seedvr"] = prior_module + if comfy_pkg is not None: + if prior_mm_attr is sentinel: + if hasattr(comfy_pkg, "model_management"): + delattr(comfy_pkg, "model_management") + else: + setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py new file mode 100644 index 000000000..a27a8f8df --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -0,0 +1,57 @@ +from unittest.mock import patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _schema_ids(items): + return [item.id for item in items] + + +def test_seedvr2_post_processing_schema(): + schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() + + assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] + assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] + assert schema.inputs[2].default == "lab" + assert schema.outputs[0].get_io_type() == "IMAGE" + + +def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): + decoded = torch.full((1, 3, 4, 4), 0.25) + reference = torch.full((1, 3, 4, 4), 0.75) + + def _lab(content, style): + raise torch.cuda.OutOfMemoryError("CUDA out of memory") + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + try: + nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( + decoded, reference, torch.device("cpu"), "lab", + ) + except RuntimeError as exc: + assert "color_correction_method=lab" in str(exc) + assert " method=lab" not in str(exc) + else: + raise AssertionError("expected RuntimeError for one-frame LAB OOM") + + +def test_seedvr2_post_processing_unknown_color_correction_method_raises(): + decoded = torch.zeros(1, 2, 4, 4, 3) + original = torch.zeros(1, 2, 4, 4, 3) + try: + nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") + except ValueError as exc: + assert "color_correction_method" in str(exc) + else: + raise AssertionError("expected ValueError for unknown color_correction_method") From cfb9c31c99611a61090c97f2216c5e54db490227 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Thu, 11 Jun 2026 10:41:23 -0500 Subject: [PATCH 07/12] Add SeedVR2 sampler coverage --- .../test_seedvr_progressive_sampler.py | 95 +++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests-unit/comfy_test/test_seedvr_progressive_sampler.py diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py new file mode 100644 index 000000000..146b81225 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py @@ -0,0 +1,95 @@ +"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sample # noqa: E402 +import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 +from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 + +_LAT_C = 16 +_COND_C = 17 + + +def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): + """Build minimal SeedVR2-shaped sampling inputs.""" + samples_5d = torch.arange( + B * _LAT_C * T * H * W, dtype=torch.float32 + ).reshape(B, _LAT_C, T, H, W) + samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() + + cond_5d = torch.arange( + B * _COND_C * T * H * W, dtype=torch.float32 + ).reshape(B, _COND_C, T, H, W) + 10000.0 + cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() + + text_pos = torch.zeros(1, 4, 32) + text_neg = torch.zeros(1, 4, 32) + positive = [[text_pos, {"condition": cond.clone()}]] + negative = [[text_neg, {"condition": cond.clone()}]] + latent_image = {"samples": samples} + return latent_image, positive, negative, samples_5d, cond_5d + + +def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): + return latent_image + + +def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): + """Return a tensor whose values encode ``(seed, position)``.""" + base = torch.arange( + latent_image.numel(), dtype=torch.float32 + ).reshape(latent_image.shape) + return base + float(seed) * 1e6 + + +def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): + schema = SeedVR2ProgressiveSampler.define_schema() + inputs = {item.id: item for item in schema.inputs} + + assert inputs["chunking_mode"].options == ["manual", "auto"] + assert inputs["chunking_mode"].default == "manual" + + +def test_vram_seed_frames_per_chunk_predicts_4n1_clamped_to_t_pixel(): + """VRAM chunk-size law: seed = nearest 4n+1 to 4*(free_GB - 3), clamped to [1, t_pixel].""" + gib = 1024 ** 3 + seed = nodes_seedvr_mod._seedvr2_vram_seed_frames_per_chunk + assert seed(20 * gib, 65) == 65 # 4*(20-3)=68 -> 4n+1 69 -> clamp to t_pixel 65 + assert seed(6 * gib, 97) == 13 # 4*(6-3)=12 -> nearest 4n+1 13 + assert seed(2 * gib, 97) == 1 # below margin -> floor at 1 + + +@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) +def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): + """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" + latent, pos, neg, _, _ = _make_inputs(T=5) + + sampler_called = {"n": 0} + + def _should_not_be_called(*args, **kwargs): + sampler_called["n"] += 1 + return torch.zeros(1) + + with patch.object(comfy.sample, "sample", + side_effect=_should_not_be_called), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + with pytest.raises(ValueError) as excinfo: + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent=latent, + denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, + ) + assert str(bad_chunk) in str(excinfo.value) + assert sampler_called["n"] == 0 From e5959653922157361e8bafa21f2d28aa60282025 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Wed, 1 Jul 2026 07:07:41 -0500 Subject: [PATCH 08/12] Remove SeedVR2 VAE memory convolution workaround --- comfy/ldm/seedvr/vae.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 3996b9103..501896516 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -30,7 +30,6 @@ from comfy.ldm.modules.diffusionmodules.model import vae_attention import math from enum import Enum -from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND import logging import comfy.model_management @@ -597,23 +596,6 @@ class InflatedCausalConv3d(ops.Conv3d): self.memory_limit = value def _conv_forward(self, input, weight, bias, *args, **kwargs): - if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and - weight.dtype in (torch.float16, torch.bfloat16) and - hasattr(torch.backends.cudnn, 'is_available') and - torch.backends.cudnn.is_available() and - getattr(torch.backends.cudnn, 'enabled', True)): - try: - out = torch.cudnn_convolution( - input, weight, self.padding, self.stride, self.dilation, self.groups, - benchmark=False, deterministic=False, allow_tf32=True - ) - if bias is not None: - out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) - return out - except RuntimeError: - pass - except NotImplementedError: - pass try: return super()._conv_forward(input, weight, bias, *args, **kwargs) except NotImplementedError: From f437d87155f1744dfb8a418879d2958dc90a10c7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 1 Jul 2026 22:17:51 -0400 Subject: [PATCH 09/12] Cleanups using AGENTS.md --- comfy/ldm/seedvr/attention.py | 7 - comfy/ldm/seedvr/color_fix.py | 13 +- comfy/ldm/seedvr/model.py | 234 ++++------------ comfy/ldm/seedvr/vae.py | 255 +++++++----------- comfy/model_detection.py | 6 +- comfy/sd.py | 41 ++- comfy/supported_models.py | 4 + comfy/supported_models_base.py | 4 +- comfy_extras/nodes_seedvr.py | 56 +--- .../test_seedvr2_conditioning.py | 34 +-- .../test_seedvr2_post_processing.py | 18 +- tests-unit/comfy_test/model_detection_test.py | 26 +- .../comfy_test/test_seedvr2_internals.py | 2 +- tests-unit/comfy_test/test_seedvr2_model.py | 44 +++ .../comfy_test/test_seedvr2_vae_tiled.py | 58 ++-- 15 files changed, 313 insertions(+), 489 deletions(-) diff --git a/comfy/ldm/seedvr/attention.py b/comfy/ldm/seedvr/attention.py index 29ffded38..5d4054ab9 100644 --- a/comfy/ldm/seedvr/attention.py +++ b/comfy/ldm/seedvr/attention.py @@ -60,14 +60,7 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a q_i = q_i.permute(1, 0, 2).unsqueeze(0) k_i = k_i.permute(1, 0, 2).unsqueeze(0) v_i = v_i.permute(1, 0, 2).unsqueeze(0) - out_dtype = q_i.dtype - if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): - q_i = q_i.to(torch.bfloat16) - k_i = k_i.to(torch.bfloat16) - v_i = v_i.to(torch.bfloat16) out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) - if out_i.dtype != out_dtype: - out_i = out_i.to(out_dtype) out.append(out_i.squeeze(0).permute(1, 0, 2)) out = torch.cat(out, dim=0) diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py index 7ddfc03af..440b3d26c 100644 --- a/comfy/ldm/seedvr/color_fix.py +++ b/comfy/ldm/seedvr/color_fix.py @@ -2,8 +2,6 @@ import torch import torch.nn.functional as F from torch import Tensor -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.vae import safe_interpolate_operation from comfy.ldm.seedvr.constants import ( CIELAB_DELTA, CIELAB_KAPPA, @@ -28,7 +26,7 @@ def wavelet_blur(image: Tensor, radius): kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) return output @@ -49,8 +47,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: if content_feat.shape != style_feat.shape: # Resize style to match content spatial dimensions if len(content_feat.shape) >= 3: - # safe_interpolate_operation handles FP16 conversion automatically - style_feat = safe_interpolate_operation( + style_feat = F.interpolate( style_feat, size=content_feat.shape[-2:], mode='bilinear', @@ -65,7 +62,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: del style_high_freq # Free memory immediately if content_high_freq.shape != style_low_freq.shape: - style_low_freq = safe_interpolate_operation( + style_low_freq = F.interpolate( style_low_freq, size=content_high_freq.shape[-2:], mode='bilinear', @@ -227,7 +224,7 @@ def lab_color_transfer( content_feat = wavelet_reconstruction(content_feat, style_feat) if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( + style_feat = F.interpolate( style_feat, size=content_feat.shape[-2:], mode='bilinear', @@ -308,7 +305,7 @@ def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( + style_feat = F.interpolate( style_feat, size=content_feat.shape[-2:], mode='bilinear', diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index e7d3deb35..ee50449a4 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -1,7 +1,5 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union, List, Dict, Any, Callable -import einops -from einops import rearrange import torch.nn.functional as F from math import ceil, pi import torch @@ -23,52 +21,6 @@ from comfy.ldm.seedvr.constants import ( SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, ) import comfy.model_management -import numbers - -def _torch_float8_types(): - return tuple( - getattr(torch, name) - for name in ( - "float8_e4m3fn", - "float8_e4m3fnuz", - "float8_e5m2", - "float8_e5m2fnuz", - "float8_e8m0fnu", - ) - if hasattr(torch, name) - ) - -class CustomRMSNorm(nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): - super(CustomRMSNorm, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) - else: - self.register_parameter('weight', None) - - def forward(self, input): - - dims = tuple(range(-len(self.normalized_shape), 0)) - - # Norm statistics in fp32 (fp16 variance underflows); activations return - # in the input dtype so downstream linears run at the model compute dtype. - normalized = input.float() - variance = normalized.pow(2).mean(dim=dims, keepdim=True) - rms = torch.sqrt(variance + self.eps) - - normalized = normalized / rms - - if self.elementwise_affine: - return (normalized * self.weight.to(torch.float32)).to(input.dtype) - return normalized.to(input.dtype) class Cache: def __init__(self, disable=False, prefix="", cache=None): @@ -81,12 +33,10 @@ class Cache: return fn() key = self.prefix + key - try: - result = self.cache[key] - except KeyError: + if key not in self.cache: result = fn() self.cache[key] = result - return result + return self.cache[key] def namespace(self, namespace: str): return Cache( @@ -144,15 +94,6 @@ class MMArg: vid: Any txt: Any -def safe_pad_operation(x, padding, mode='constant', value=0.0): - try: - return F.pad(x, padding, mode=mode, value=value) - except RuntimeError as e: - if "not implemented for" in str(e) and x.dtype in (torch.float16, torch.bfloat16): - return F.pad(x.float(), padding, mode=mode, value=value).to(x.dtype) - raise - - def get_args(key: str, args: List[Any]) -> List[Any]: return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] @@ -235,8 +176,6 @@ class RotaryEmbedding(nn.Module): theta = 10000, max_freq = 10, learned_freq = False, - cache_if_possible = True, - cache_max_seq_len = 8192 ): super().__init__() @@ -247,12 +186,6 @@ class RotaryEmbedding(nn.Module): elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - self.cache_if_possible = cache_if_possible - self.cache_max_seq_len = cache_max_seq_len - - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_freqs_seq_len = 0 - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) self.learned_freq = learned_freq @@ -310,29 +243,10 @@ class RotaryEmbedding(nn.Module): seq_len: int | None = None, offset = 0 ): - should_cache = ( - self.cache_if_possible and - not self.learned_freq and - exists(seq_len) and - self.freqs_for != 'pixel' and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_freqs) and \ - (offset + seq_len) <= self.cached_freqs_seq_len - ): - return self.cached_freqs[offset:(offset + seq_len)].detach() - freqs = self.freqs freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) - - if should_cache and offset == 0: - self.cached_freqs[:seq_len] = freqs.detach() - self.cached_freqs_seq_len = seq_len + freqs = freqs.unsqueeze(-1).expand(*freqs.shape, 2).flatten(-2) return freqs @@ -346,7 +260,7 @@ class RotaryEmbeddingBase(nn.Module): ) freqs = self.rope.freqs del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) + self.rope.register_buffer("freqs", freqs.detach()) def get_axial_freqs(self, *dims): return self.rope.get_axial_freqs(*dims) @@ -371,12 +285,12 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d): ]: freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) freqs = freqs.to(device=q.device) - q = rearrange(q, "L h d -> h L d") - k = rearrange(k, "L h d -> h L d") + q = q.transpose(0, 1) + k = k.transpose(0, 1) q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "h L d -> L h d") - k = rearrange(k, "h L d -> L h d") + q = q.transpose(0, 1) + k = k.transpose(0, 1) return q, k @torch._dynamo.disable @@ -407,11 +321,10 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase): dim=dim // rope_dim, freqs_for="lang", theta=ROPE_THETA, - cache_if_possible=False, ) freqs = self.rope.freqs del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) + self.rope.register_buffer("freqs", freqs.detach()) self.mm = True def slice_at_dim(t, dim_slice: slice, *, dim): @@ -423,10 +336,10 @@ def slice_at_dim(t, dim_slice: slice, *, dim): # rotary embedding helper functions def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) + x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2) x1, x2 = x.unbind(dim = -1) x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') + return x.flatten(-2) def exists(val): return val is not None @@ -465,7 +378,7 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: cos = torch.cos(angles) sin = torch.sin(angles) out = torch.stack([cos, -sin, sin, cos], dim=-1) - return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) + return out.reshape(*out.shape[:-1], 2, 2) def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: @@ -516,19 +429,19 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): vid_freqs = vid_freqs.to(target_device) if txt_freqs.device != target_device: txt_freqs = txt_freqs.to(target_device) - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = vid_q.transpose(0, 1) + vid_k = vid_k.transpose(0, 1) vid_q = _apply_rope1_partial(vid_q, vid_freqs) vid_k = _apply_rope1_partial(vid_k, vid_freqs) - vid_q = rearrange(vid_q, "h L d -> L h d") - vid_k = rearrange(vid_k, "h L d -> L h d") + vid_q = vid_q.transpose(0, 1) + vid_k = vid_k.transpose(0, 1) - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = txt_q.transpose(0, 1) + txt_k = txt_k.transpose(0, 1) txt_q = _apply_rope1_partial(txt_q, txt_freqs) txt_k = _apply_rope1_partial(txt_k, txt_freqs) - txt_q = rearrange(txt_q, "h L d -> L h d") - txt_k = rearrange(txt_k, "h L d -> L h d") + txt_q = txt_q.transpose(0, 1) + txt_k = txt_k.transpose(0, 1) return vid_q, vid_k, txt_q, txt_k @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks @@ -684,7 +597,7 @@ def window( ): hid = unflatten(hid, hid_shape) hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device) hid, hid_shape = flatten(list(chain(*hid))) return hid, hid_shape, hid_windows @@ -747,8 +660,8 @@ class NaSwinAttention(NaMMAttention): ) vid_qkv_win = window_partition(vid_qkv) - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + vid_qkv_win = vid_qkv_win.reshape(vid_qkv_win.shape[0], 3, self.heads, self.head_dim) + txt_qkv = txt_qkv.reshape(txt_qkv.shape[0], 3, self.heads, self.head_dim) vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) txt_q, txt_k, txt_v = txt_qkv.unbind(1) @@ -768,19 +681,19 @@ class NaSwinAttention(NaMMAttention): elif self.rope.mm: # repeat text q and k for window mmrope _, num_h, _ = txt_q.shape - txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = txt_q.flatten(1, 2) txt_q_repeat = unflatten(txt_q_repeat, txt_shape) txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] txt_q_repeat = list(chain(*txt_q_repeat)) txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) - txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim) - txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = txt_k.flatten(1, 2) txt_k_repeat = unflatten(txt_k_repeat, txt_shape) txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] txt_k_repeat = list(chain(*txt_k_repeat)) txt_k_repeat, _ = flatten(txt_k_repeat) - txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim) vid_q, vid_k, txt_q, txt_k = self.rope( vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win @@ -799,16 +712,16 @@ class NaSwinAttention(NaMMAttention): v=concat_win(vid_v, txt_v), heads=self.heads, skip_reshape=True, skip_output_reshape=True, cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() ), cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() ), ) vid_out, txt_out = unconcat_win(out) - vid_out = rearrange(vid_out, "l h d -> l (h d)") - txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = vid_out.flatten(1, 2) + txt_out = txt_out.flatten(1, 2) vid_out = window_reverse(vid_out) vid_out, txt_out = self.proj_out(vid_out, txt_out) @@ -1005,7 +918,9 @@ class PatchOut(nn.Module): ) -> torch.Tensor: t, h, w = self.patch_size vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + b, T, H, W, channels = vid.shape + c = channels // (t * h * w) + vid = vid.view(b, T, H, W, t, h, w, c).permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(b, c, T * t, H * h, W * w) if t > 1: vid = vid[:, :, (t - 1) :] return vid @@ -1015,7 +930,7 @@ class NaPatchOut(PatchOut): self, vid: torch.FloatTensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test + cache: Cache = Cache(disable=True), vid_shape_before_patchify = None ) -> Tuple[ torch.FloatTensor, @@ -1028,7 +943,9 @@ class NaPatchOut(PatchOut): if not (t == h == w == 1): vid = unflatten(vid, vid_shape) for i in range(len(vid)): - vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + T, H, W, channels = vid[i].shape + c = channels // (t * h * w) + vid[i] = vid[i].view(T, H, W, t, h, w, c).permute(0, 3, 1, 4, 2, 5, 6).reshape(T * t, H * h, W * w, c) if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] vid, vid_shape = flatten(vid) @@ -1056,7 +973,8 @@ class PatchIn(nn.Module): if t > 1: assert vid.size(2) % t == 1 vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + b, c, Tt, Hh, Ww = vid.shape + vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c) vid = self.proj(vid) return vid @@ -1065,7 +983,7 @@ class NaPatchIn(PatchIn): self, vid: torch.Tensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test + cache: Cache = Cache(disable=True), ) -> torch.Tensor: cache = cache.namespace("patch") vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) @@ -1075,7 +993,8 @@ class NaPatchIn(PatchIn): for i in range(len(vid)): if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) - vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + Tt, Hh, Ww, c = vid[i].shape + vid[i] = vid[i].view(Tt // t, t, Hh // h, h, Ww // w, w, c).permute(0, 2, 4, 1, 3, 5, 6).reshape(Tt // t, Hh // h, Ww // w, t * h * w * c) vid, vid_shape = flatten(vid) vid = self.proj(vid) @@ -1102,17 +1021,14 @@ class AdaSingle(nn.Module): self.emb_dim = emb_dim self.layers = layers - param_kwargs = {"device": device} - fp8_types = _torch_float8_types() - if dtype is not None and dtype not in fp8_types: - param_kwargs["dtype"] = dtype + param_kwargs = {"device": device, "dtype": dtype} for l in layers: if "in" in modes: - self.register_parameter(f"{l}_shift", nn.Parameter(torch.zeros(dim, **param_kwargs))) - self.register_parameter(f"{l}_scale", nn.Parameter(torch.ones(dim, **param_kwargs))) + self.register_parameter(f"{l}_shift", nn.Parameter(torch.empty(dim, **param_kwargs))) + self.register_parameter(f"{l}_scale", nn.Parameter(torch.empty(dim, **param_kwargs))) if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.zeros(dim, **param_kwargs))) + self.register_parameter(f"{l}_gate", nn.Parameter(torch.empty(dim, **param_kwargs))) def forward( self, @@ -1125,7 +1041,7 @@ class AdaSingle(nn.Module): hid_len: Optional[torch.LongTensor] = None, # b ) -> torch.FloatTensor: idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :] emb = expand_dims(emb, 1, hid.ndim + 1) if hid_len is not None: @@ -1145,17 +1061,6 @@ class AdaSingle(nn.Module): getattr(self, f"{layer}_gate", None), ) - fp8_types = _torch_float8_types() - if fp8_types: - target_dtype = hid.dtype - - if shiftB is not None and shiftB.dtype in fp8_types: - shiftB = shiftB.to(target_dtype) - if scaleB is not None and scaleB.dtype in fp8_types: - scaleB = scaleB.to(target_dtype) - if gateB is not None and gateB.dtype in fp8_types: - gateB = gateB.to(target_dtype) - if mode == "in": return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) if mode == "out": @@ -1213,7 +1118,7 @@ def flatten( torch.LongTensor, # (b n) ]: assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device) hid = torch.cat([x.flatten(0, -2) for x in hid]) return hid, shape @@ -1227,19 +1132,6 @@ def unflatten( hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] return hid -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - class NaDiT(nn.Module): def __init__( @@ -1275,23 +1167,11 @@ class NaDiT(nn.Module): emb_dim = vid_dim * 6 window = num_layers * [(4,3,3)] ada = AdaSingle - norm = CustomRMSNorm - qk_norm = CustomRMSNorm + norm = operations.RMSNorm + qk_norm = operations.RMSNorm super().__init__() - # ``torch.empty`` returns uninitialized memory, not zeros. The - # SeedVR2Conditioning fail-loud guard at - # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" - # from "buffer was never populated by the file" by checking - # ``positive_conditioning.abs().sum() == 0``. That sentinel is only - # reliable if the post-construction buffer state is deterministically - # zero, so explicitly zero-fill here rather than relying on the - # allocator's zero-on-alloc behavior (allocator-dependent and not - # contractual). When ``load_state_dict`` populates these buffers - # from a properly-baked SeedVR2 .safetensors, the in-place copy - # overwrites the zeros with the universal SeedVR2 conditioning - # tensors (shape (58, 5120) and (64, 5120) bf16). - self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) - self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) + self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype)) self.vid_in = NaPatchIn( in_channels=vid_in_channels, patch_size=patch_size, @@ -1354,7 +1234,7 @@ class NaDiT(nn.Module): self.vid_out_norm = None if vid_out_norm is not None: - self.vid_out_norm = CustomRMSNorm( + self.vid_out_norm = operations.RMSNorm( normalized_shape=vid_dim, eps=norm_eps, elementwise_affine=True, @@ -1369,7 +1249,7 @@ class NaDiT(nn.Module): ) def _resolve_text_conditioning(self, context, cond_or_uncond=None): - if context is None or getattr(context, "numel", lambda: None)() == 0: + if context is None or context.numel() == 0: context = self.positive_conditioning return flatten([context]) if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): @@ -1407,7 +1287,7 @@ class NaDiT(nn.Module): x, timestep, context, # l c - disable_cache: bool = False, # for test # TODO ? // gives an error when set to True + disable_cache: bool = False, **kwargs ): transformer_options = kwargs.get("transformer_options", {}) @@ -1483,5 +1363,5 @@ class NaDiT(nn.Module): vid = unflatten(vid, vid_shape) out = torch.stack(vid) out = out.movedim(-1, 1) - out = rearrange(out, "b c t h w -> b (c t) h w") + out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4]) return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 501896516..5daab022a 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -1,15 +1,11 @@ -from contextlib import nullcontext from typing import Literal, Optional, Tuple -import gc import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange from torch import Tensor from contextlib import contextmanager from comfy.utils import ProgressBar -from comfy.ldm.seedvr.model import safe_pad_operation from comfy.ldm.seedvr.constants import ( BYTEDANCE_BLOCK_OUT_CHANNELS, BYTEDANCE_GN_CHUNKS_FP16, @@ -58,13 +54,6 @@ def _seedvr2_clamped_spatial_overlap(overlap, tile_size): return min(overlap, tile_size - 1) -def _seedvr2_clear_temporal_memory(model): - for module in model.modules(): - if hasattr(module, "memory"): - module.memory = None - - -@torch.inference_mode() def tiled_vae( x, vae_model, @@ -75,10 +64,6 @@ def tiled_vae( encode=True, **kwargs, ): - gc.collect() - comfy.model_management.soft_empty_cache() - - x = x.to(next(vae_model.parameters()).dtype) if x.ndim != 5: x = x.unsqueeze(2) @@ -121,7 +106,6 @@ def tiled_vae( count = None def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): device = torch.device(device) - _seedvr2_clear_temporal_memory(model) t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() old_device = getattr(model, "device", None) model.device = device @@ -133,7 +117,7 @@ def tiled_vae( setattr(model, slicing_attr, slicing_min_size) try: if encode: - out = model.encode(t_chunk)[0] + out = model.encode(t_chunk) else: out = model.decode_(t_chunk) finally: @@ -141,8 +125,6 @@ def tiled_vae( setattr(model, slicing_attr, old_slicing_min_size) if old_device is not None: model.device = old_device - if isinstance(out, (tuple, list)): - out = out[0] if out.ndim == 4: out = out.unsqueeze(2) return out.to(storage_device) @@ -169,8 +151,6 @@ def tiled_vae( bar = ProgressBar(total_tiles) single_spatial_tile = h <= ti_h and w <= ti_w - _seedvr2_clear_temporal_memory(vae_model) - def run_tile(tile_index, tile_range): y_idx, y_end, x_idx, x_end = tile_range tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] @@ -186,7 +166,6 @@ def tiled_vae( if single_spatial_tile: result = tile_out[:, :, :target_d, :target_h, :target_w] - _seedvr2_clear_temporal_memory(vae_model) if result.device != x.device: result = result.to(x.device).to(x.dtype) if x.shape[2] == 1 and sf_t == 1: @@ -241,7 +220,6 @@ def tiled_vae( bar.update(1) result.div_(count.clamp(min=1e-6)) - _seedvr2_clear_temporal_memory(vae_model) if result.device != x.device: result = result.to(x.device).to(x.dtype) @@ -336,7 +314,6 @@ class Attention(nn.Module): eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, out_dim: int = None, pre_only=False, ): @@ -356,10 +333,6 @@ class Attention(nn.Module): self.out_dim = out_dim if out_dim is not None else query_dim self.pre_only = pre_only - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 @@ -480,21 +453,21 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: input_dtype = x.dtype if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") + x = x.permute(0, 2, 3, 1) x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") + x = x.permute(0, 3, 1, 2) return x.to(input_dtype) if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") + x = x.permute(0, 2, 3, 4, 1) x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") + x = x.permute(0, 4, 1, 2, 3) return x.to(input_dtype) if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): if x.ndim <= 4: return norm_layer(x).to(input_dtype) if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") + b, c, t, h, w = x.shape + x = x.transpose(1, 2).reshape(b * t, c, h, w) memory_occupy = x.numel() * x.element_size() / 1024**3 if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) @@ -504,54 +477,16 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = list(x.chunk(num_chunks, dim=1)) weights = norm_layer.weight.chunk(num_chunks, dim=0) biases = norm_layer.bias.chunk(num_chunks, dim=0) - for i, (w, b) in enumerate(zip(weights, biases)): - x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + for i, (w, bias) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, bias, norm_layer.eps) x[i] = x[i].to(input_dtype) x = torch.cat(x, dim=1) else: x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2) return x.to(input_dtype) raise NotImplementedError -def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): - problematic_modes = ['bilinear', 'bicubic', 'trilinear'] - - if mode in problematic_modes: - try: - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - except RuntimeError as e: - if ("not implemented for 'Half'" in str(e) or - "compute_indices_weights" in str(e)): - original_dtype = x.dtype - return F.interpolate( - x.float(), - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ).to(original_dtype) - else: - raise e - else: - # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - _receptive_field_t = Literal["half", "full"] def extend_head(tensor, times: int = 2, memory = None): @@ -585,7 +520,6 @@ class InflatedCausalConv3d(ops.Conv3d): **kwargs, ): self.inflation_mode = inflation_mode - self.memory = None super().__init__(*args, **kwargs) self.temporal_padding = self.padding[0] self.padding = (0, *self.padding[1:]) @@ -620,18 +554,19 @@ class InflatedCausalConv3d(ops.Conv3d): return super().forward(x) # Compute tensor shape after concat & padding. - shape = torch.tensor(x.size()) + shape = list(x.size()) if prev_cache is not None: shape[split_dim - 1] += prev_cache.size(split_dim - 1) - shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) - memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + for i, pad_sum in enumerate((padding[4] + padding[5], padding[2] + padding[3], padding[0] + padding[1])): + shape[-3 + i] += pad_sum + memory_occupy = math.prod(shape) * x.element_size() / 1024**3 # GiB if memory_occupy < self.memory_limit or split_dim == x.ndim: x_concat = x if prev_cache is not None: x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) def pad_and_forward(): - padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + padded = F.pad(x_concat, padding, mode='constant', value=0.0) if not padded.is_contiguous(): padded = padded.contiguous() with ignore_padding(self): @@ -689,46 +624,57 @@ class InflatedCausalConv3d(ops.Conv3d): def forward( self, input, - memory_state: MemoryState = MemoryState.UNSET + memory_state: MemoryState = MemoryState.UNSET, + memory_cache = None, ) -> Tensor: assert memory_state != MemoryState.UNSET + if memory_cache is None: + memory_cache = {} if memory_state != MemoryState.ACTIVE: - self.memory = None + memory_cache.pop(self, None) if ( math.isinf(self.memory_limit) and torch.is_tensor(input) ): - return self.basic_forward(input, memory_state) - return self.slicing_forward(input, memory_state) + return self.basic_forward(input, memory_state, memory_cache) + return self.slicing_forward(input, memory_state, memory_cache) - def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): + def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET, memory_cache = None): mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory, times=-1) + memory = memory_cache.get(self) if memory_cache is not None else None + if (memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=memory, times=-1) else: input = extend_head(input, times=self.temporal_padding * 2) - memory = ( + next_memory = ( input[:, :, mem_size:].detach() if (mem_size != 0 and memory_state != MemoryState.DISABLED) else None ) - if memory_state != MemoryState.DISABLED: - self.memory = memory + if memory_cache is not None and memory_state != MemoryState.DISABLED: + if next_memory is None: + memory_cache.pop(self, None) + else: + memory_cache[self] = next_memory return super().forward(input) def slicing_forward( self, input, memory_state: MemoryState = MemoryState.UNSET, + memory_cache = None, ) -> Tensor: + if memory_cache is None: + memory_cache = {} squeeze_out = False if torch.is_tensor(input): input = [input] squeeze_out = True cache_size = self.kernel_size[0] - self.stride[0] + memory = memory_cache.get(self) if memory_cache is not None else None cache = cache_send_recv( - input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 + input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2 ) # Single GPU inference - simplified memory management @@ -740,7 +686,7 @@ class InflatedCausalConv3d(ops.Conv3d): input[0] = torch.cat([cache, input[0]], dim=2) cache = None if cache_size <= input[-1].size(2): - self.memory = input[-1][:, :, -cache_size:].detach().contiguous() + memory_cache[self] = input[-1][:, :, -cache_size:].detach().contiguous() padding = tuple(x for x in reversed(self.padding) for _ in range(2)) for i in range(len(input)): @@ -802,17 +748,10 @@ class Upsample3D(nn.Module): self.temporal_ratio = 2 if temporal_up else 1 self.spatial_ratio = 2 if spatial_up else 1 - # [Override] MAGViT v2 learnable upsample upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio self.upscale_conv = ops.Conv3d( self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 ) - identity = ( - torch.eye(self.channels) - .repeat(upscale_ratio, 1) - .reshape_as(self.upscale_conv.weight) - ) - self.upscale_conv.weight.data.copy_(identity) self.conv = conv @@ -820,23 +759,27 @@ class Upsample3D(nn.Module): self, hidden_states: torch.FloatTensor, memory_state=None, + memory_cache=None, **kwargs, ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels hidden_states = self.upscale_conv(hidden_states) - hidden_states = rearrange( - hidden_states, - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, + b, channels, f, h, w = hidden_states.shape + c = channels // (self.spatial_ratio * self.spatial_ratio * self.temporal_ratio) + hidden_states = hidden_states.view(b, self.spatial_ratio, self.spatial_ratio, self.temporal_ratio, c, f, h, w) + hidden_states = hidden_states.permute(0, 4, 5, 3, 6, 1, 7, 2).reshape( + b, + c, + f * self.temporal_ratio, + h * self.spatial_ratio, + w * self.spatial_ratio, ) if self.temporal_up and memory_state != MemoryState.ACTIVE: hidden_states = remove_head(hidden_states) - hidden_states = self.conv(hidden_states, memory_state=memory_state) + hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -879,6 +822,7 @@ class Downsample3D(nn.Module): self, hidden_states: torch.FloatTensor, memory_state = None, + memory_cache = None, **kwargs, ) -> torch.FloatTensor: @@ -890,11 +834,11 @@ class Downsample3D(nn.Module): if self.spatial_down: pad = (0, 1, 0, 1) - hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states, memory_state=memory_state) + hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -962,7 +906,7 @@ class ResnetBlock3D(nn.Module): ) def forward( - self, input_tensor, temb, memory_state = None, **kwargs + self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs ): hidden_states = input_tensor @@ -970,7 +914,7 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states, memory_state=memory_state) + hidden_states = self.conv1(hidden_states, memory_state=memory_state, memory_cache=memory_cache) if self.time_emb_proj is not None: if not self.skip_time_act: @@ -985,10 +929,10 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) + hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state, memory_cache=memory_cache) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor @@ -1055,15 +999,16 @@ class DownEncoderBlock3D(nn.Module): self, hidden_states: torch.FloatTensor, memory_state = None, + memory_cache = None, **kwargs, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) + hidden_states = downsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -1132,15 +1077,16 @@ class UpDecoderBlock3D(nn.Module): self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, - memory_state=None + memory_state=None, + memory_cache=None, ) -> torch.FloatTensor: for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) + hidden_states = upsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -1203,7 +1149,6 @@ class UNetMidBlock3D(nn.Module): residual_connection=True, bias=True, upcast_softmax=True, - _from_deprecated_attn_block=True, ) ) else: @@ -1226,17 +1171,16 @@ class UNetMidBlock3D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, temb=None, memory_state=None): + def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None): video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + b, c, f, h, w = hidden_states.shape + hidden_states = hidden_states.transpose(1, 2).reshape(b * f, c, h, w) hidden_states = attn(hidden_states, temb=temb) - hidden_states = rearrange( - hidden_states, "(b f) c h w -> b c f h w", f=video_length - ) - hidden_states = resnet(hidden_states, temb, memory_state=memory_state) + hidden_states = hidden_states.reshape(b, video_length, c, h, w).transpose(1, 2) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) return hidden_states @@ -1327,22 +1271,23 @@ class Encoder3D(nn.Module): def forward( self, sample: torch.FloatTensor, - memory_state = None + memory_state = None, + memory_cache = None, ) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state = memory_state) + sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) # down for down_block in self.down_blocks: - sample = down_block(sample, memory_state=memory_state) + sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache) # middle - sample = self.mid_block(sample, memory_state=memory_state) + sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state = memory_state) + sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) return sample @@ -1436,24 +1381,25 @@ class Decoder3D(nn.Module): sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None, memory_state = None, + memory_cache = None, ) -> torch.FloatTensor: sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state=memory_state) + sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds, memory_state=memory_state) + sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) + sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) return sample @@ -1529,22 +1475,23 @@ class VideoAutoencoderKL(nn.Module): return decoded def _encode( - self, x, memory_state = MemoryState.DISABLED + self, x, memory_state = MemoryState.DISABLED, memory_cache = None ) -> torch.Tensor: _x = x.to(self.device) - h = self.encoder(_x, memory_state=memory_state) + h = self.encoder(_x, memory_state=memory_state, memory_cache=memory_cache) return h.to(x.device) def _decode( - self, z, memory_state = MemoryState.DISABLED + self, z, memory_state = MemoryState.DISABLED, memory_cache = None ) -> torch.Tensor: _z = z.to(self.device) - output = self.decoder(_z, memory_state=memory_state) + output = self.decoder(_z, memory_state=memory_state, memory_cache=memory_cache) return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: sp_size =1 if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + memory_cache = {} split_size = max( self.slicing_sample_min_size * sp_size, getattr(self, "temporal_downsample_factor", 1), @@ -1558,17 +1505,14 @@ class VideoAutoencoderKL(nn.Module): self._encode( torch.cat((x[:, :, :1], x_slices[0]), dim=2), memory_state=MemoryState.INITIALIZING, + memory_cache=memory_cache, ) ] for x_idx in range(1, len(x_slices)): encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache) ) out = torch.cat(encoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None return out else: return self._encode(x) @@ -1576,22 +1520,20 @@ class VideoAutoencoderKL(nn.Module): def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: sp_size = 1 if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + memory_cache = {} z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) decoded_slices = [ self._decode( torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING + memory_state=MemoryState.INITIALIZING, + memory_cache=memory_cache, ) ] for z_idx in range(1, len(z_slices)): decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache) ) out = torch.cat(decoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None return out else: return self._decode(z) @@ -1612,32 +1554,25 @@ class VideoAutoencoderKL(nn.Module): return _unwrap(self.decode_(latent)) class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - # Signals to comfy.sd.VAE that this model performs its own VAE tiling, so the - # generic tiled-decode/encode dispatch defers to decode_tiled/encode_tiled below. - comfy_handles_tiling = True - def __init__( self, *args, spatial_downsample_factor = 8, temporal_downsample_factor = 4, - freeze_encoder = True, **kwargs, ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder self.enable_tiling = False super().__init__(*args, **kwargs) self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) def forward(self, x: torch.FloatTensor): - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) + z, p = self._encode_with_raw_latent(x) x = self.decode(z) return x, z, p - def encode(self, x, orig_dims=None): + def _encode_with_raw_latent(self, x): if x.ndim == 4: x = x.unsqueeze(2) x = x.to(dtype=next(self.parameters()).dtype) @@ -1646,6 +1581,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): z = p.squeeze(2) return z, p + def encode(self, x, orig_dims=None): + z, _ = self._encode_with_raw_latent(x) + return z + def decode(self, z, seedvr2_tiling=None): seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling if not isinstance(seedvr2_tiling, dict): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bcca99251..bf44b832c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1151,9 +1151,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return unet_config -def model_config_from_unet_config(unet_config, state_dict=None): +def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""): for model_config in comfy.supported_models.models: - if model_config.matches(unet_config, state_dict): + if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix): return model_config(unet_config) logging.error("no match {}".format(unet_config)) @@ -1163,7 +1163,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata) if unet_config is None: return None - model_config = model_config_from_unet_config(unet_config, state_dict) + model_config = model_config_from_unet_config(unet_config, state_dict, unet_key_prefix) if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) diff --git a/comfy/sd.py b/comfy/sd.py index 6e1340ea8..06c6196d3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,3 @@ -import inspect import json import torch from enum import Enum @@ -500,6 +499,8 @@ class VAE: self.upscale_index_formula = None self.extra_1d_channel = None self.crop_input = True + self.handles_tiling = False + self.format_encoded = None self.audio_sample_rate = 44100 @@ -554,6 +555,8 @@ class VAE: self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape) self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.handles_tiling = True + self.format_encoded = self.first_stage_model.comfy_format_encoded self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) @@ -1118,7 +1121,7 @@ class VAE: if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) @@ -1127,7 +1130,7 @@ class VAE: elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) else: pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) @@ -1149,7 +1152,7 @@ class VAE: args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3): + if self.handles_tiling and dims in (2, 3): tiled_args = {} if tile_x is not None: tiled_args["tile_x"] = tile_x @@ -1204,8 +1207,6 @@ class VAE: else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) - if isinstance(out, tuple): - out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1225,7 +1226,7 @@ class VAE: if self.latent_dim == 3: tile = 256 overlap = tile // 4 - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) else: samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) @@ -1234,9 +1235,8 @@ class VAE: else: samples = self.encode_tiled_(pixel_samples) - formatter = getattr(self.first_stage_model, "comfy_format_encoded", None) - if formatter is not None: - samples = formatter(samples) + if self.format_encoded is not None: + samples = self.format_encoded(samples) return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): @@ -1268,7 +1268,7 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if getattr(self.first_stage_model, "comfy_handles_tiling", False): + if self.handles_tiling: tiled_args = {} if tile_x is not None: tiled_args["tile_x"] = tile_x @@ -1298,9 +1298,8 @@ class VAE: samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - formatter = getattr(self.first_stage_model, "comfy_format_encoded", None) - if formatter is not None: - samples = formatter(samples) + if self.format_encoded is not None: + samples = self.format_encoded(samples) return samples def get_sd(self): @@ -1852,16 +1851,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) -def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): - set_dtype = model_config.set_inference_dtype - parameters = inspect.signature(set_dtype).parameters - supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) - if supports_device: - set_dtype(dtype, manual_cast_dtype, device=device) - else: - set_dtype(dtype, manual_cast_dtype) - - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -1969,7 +1958,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2110,7 +2099,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1ce5f8c91..5c849358e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1688,6 +1688,10 @@ class SeedVR2(supported_models_base.BASE): unet_config = { "image_model": "seedvr2" } + required_keys = { + "{}positive_conditioning", + "{}negative_conditioning", + } latent_format = comfy.latent_formats.SeedVR2 vae_key_prefix = ["vae."] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 572f9984e..e3a8e131f 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -54,13 +54,13 @@ class BASE: optimizations = {"fp8": False} @classmethod - def matches(s, unet_config, state_dict=None): + def matches(s, unet_config, state_dict=None, unet_key_prefix=""): for k in s.unet_config: if k not in unet_config or s.unet_config[k] != unet_config[k]: return False if state_dict is not None: for k in s.required_keys: - if k not in state_dict: + if k.format(unet_key_prefix) not in state_dict: return False return True diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 978de3e41..1fb44ac36 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -3,7 +3,6 @@ from comfy_api.latest import ComfyExtension, io import torch import math import logging -from einops import rearrange import comfy.model_management import comfy.sample @@ -101,14 +100,6 @@ def _resolve_seedvr2_diffusion_model(model): return diffusion_model -def _apply_rope_freqs_float32_cast(diffusion_model): - """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" - for module in diffusion_model.modules(): - if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): - if module.rope.freqs.data.dtype != torch.float32: - module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) - - def get_conditions(latent, latent_blur): t, h, w, c = latent.shape cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) @@ -193,7 +184,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name): images = images.reshape(b, t, c, new_h, new_w) images = cut_videos(images) - images_bthwc = rearrange(images, "b t c h w -> b t h w c") + images_bthwc = images.permute(0, 1, 3, 4, 2).contiguous() return io.NodeOutput(images_bthwc) @@ -265,12 +256,12 @@ class SeedVR2PostProcessing(io.ComfyNode): output_device = decoded_5d.device decoded_raw = cls._to_seedvr2_raw(decoded_5d) reference_raw = cls._to_seedvr2_raw(reference_5d) - decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") - reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") + decoded_flat = decoded_raw.permute(0, 1, 4, 2, 3).reshape(b * t, decoded_raw.shape[4], target_h, target_w) + reference_flat = reference_raw.permute(0, 1, 4, 2, 3).reshape(b * t, reference_raw.shape[4], target_h, target_w) output = cls._color_transfer_chunked( decoded_flat, reference_flat, output_device, color_correction_method, ) - output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) + output = output.reshape(b, t, output.shape[1], output.shape[2], output.shape[3]).permute(0, 1, 3, 4, 2) output = output.add(1.0).div(2.0).clamp(0.0, 1.0) elif color_correction_method == "none": output = decoded_5d @@ -359,7 +350,6 @@ class SeedVR2PostProcessing(io.ComfyNode): ) from e next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - comfy.model_management.soft_empty_cache() chunk_size = next_chunk_size @classmethod @@ -419,14 +409,14 @@ class SeedVR2PostProcessing(io.ComfyNode): if reference.shape[2] == height and reference.shape[3] == width: return reference b, t = reference.shape[:2] - reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") + reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(b * t, reference.shape[4], reference.shape[2], reference.shape[3]) resized = TVF.resize( reference_flat, size=(height, width), interpolation=InterpolationMode.BICUBIC, antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), ) - return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) + return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2) class SeedVR2Conditioning(io.ComfyNode): @@ -471,39 +461,12 @@ class SeedVR2Conditioning(io.ComfyNode): pos_cond = model.positive_conditioning neg_cond = model.negative_conditioning - # Fail-loud guard against silently-wrong output when a - # DiT-only ``.safetensors`` (no ``positive_conditioning`` / - # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. - # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see - # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` - # leaves them at zero when the keys are absent. Detect that state - # here rather than at ``BaseModel.extra_conds`` (per sampling step, - # wasteful) or at the resolver helper (mixes structural shape with - # semantic content). Both buffers must be checked together — partial - # bake regressions could populate one but not the other. - if ( - pos_cond.float().abs().sum().item() == 0 - and neg_cond.float().abs().sum().item() == 0 - ): - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " - f"and negative_conditioning buffers are zero-valued — model " - f"file appears to be a DiT-only export missing " - f"the SeedVR2 conditioning tensors. " - f"Re-bake the file with ``positive_conditioning`` (58, 5120) " - f"and ``negative_conditioning`` (64, 5120) keys at top level, " - f"or load via CheckpointLoaderSimple from a bundled " - f"checkpoint." - ) - - _apply_rope_freqs_float32_cast(model) - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) condition = condition.movedim(-1, 1) latent = vae_conditioning.movedim(-1, 1) - latent = rearrange(latent, "b c t h w -> b (c t) h w") - condition = rearrange(condition, "b c t h w -> b (c t) h w") + latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4]) + condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4]) negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] @@ -723,7 +686,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that OOM on long sequences. The latent enters the sampler in SeedVR2's collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that + at ``reshape(b, c * t, h, w)``); this node slices that tensor along the temporal axis, runs the configured inner sampler sequentially per chunk against the standard ``comfy.sample.sample`` entry point, and concatenates per-chunk outputs back into a single @@ -882,7 +845,6 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): "frames_per_chunk=%s.", attempt_frames_per_chunk, attempts[i + 1], ) - comfy.model_management.soft_empty_cache() # Short-circuit: total fits in one chunk -> standard path with no # chunking overhead. Output of this branch is byte-identical to the diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py index 2a6e3d430..d36e50428 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -11,7 +11,6 @@ import importlib import sys from unittest.mock import MagicMock -import pytest import torch import torch.nn as nn @@ -53,7 +52,7 @@ def _import_nodes_seedvr_isolated(): mock_mm.WINDOWS = False sys.modules["comfy.model_management"] = mock_mm if sys.modules.get("comfy") is None: - import comfy as _comfy_pkg # noqa: F401 + importlib.import_module("comfy") comfy_pkg = sys.modules.get("comfy") if comfy_pkg is not None: setattr(comfy_pkg, "model_management", mock_mm) @@ -95,11 +94,10 @@ class _Block(nn.Module): class _DiffusionModel(nn.Module): """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" - def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): + def __init__(self, n_blocks=3, conditioning_dtype=torch.float32): super().__init__() self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) - pos = torch.zeros if zero_conditioning else torch.ones - self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) + self.register_buffer("positive_conditioning", torch.ones((2, 4), dtype=conditioning_dtype)) self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) @@ -185,29 +183,3 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): ) finally: restore() - - -def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel(zero_conditioning=True) - patcher = _ModelPatcher(diffusion_model) - vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} - - with pytest.raises(RuntimeError) as excinfo: - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, vae_conditioning, - ) - - message = str(excinfo.value) - assert message.startswith( - nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX - ), ( - "Fail-loud message must use the standard " - "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " - f"can match it. Got: {message!r}" - ) - assert "positive_conditioning" in message - assert "negative_conditioning" in message - finally: - restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py index a27a8f8df..6c821136d 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import pytest import torch from comfy.cli_args import args as cli_args @@ -32,26 +33,19 @@ def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypa monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) with patch.object(nodes_seedvr, "lab_color_transfer", _lab): - try: + with pytest.raises(RuntimeError) as excinfo: nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( decoded, reference, torch.device("cpu"), "lab", ) - except RuntimeError as exc: - assert "color_correction_method=lab" in str(exc) - assert " method=lab" not in str(exc) - else: - raise AssertionError("expected RuntimeError for one-frame LAB OOM") + assert "color_correction_method=lab" in str(excinfo.value) + assert " method=lab" not in str(excinfo.value) def test_seedvr2_post_processing_unknown_color_correction_method_raises(): decoded = torch.zeros(1, 2, 4, 4, 3) original = torch.zeros(1, 2, 4, 4, 3) - try: + with pytest.raises(ValueError) as excinfo: nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") - except ValueError as exc: - assert "color_correction_method" in str(exc) - else: - raise AssertionError("expected ValueError for unknown color_correction_method") + assert "color_correction_method" in str(excinfo.value) diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 109e2b13b..587c393c9 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -2,7 +2,7 @@ from collections import defaultdict import torch -from comfy.model_detection import detect_unet_config, model_config_from_unet_config +from comfy.model_detection import detect_unet_config, model_config_from_unet, model_config_from_unet_config import comfy.supported_models @@ -76,21 +76,31 @@ def _make_flux_schnell_comfyui_sd(): def _make_seedvr2_7b_separate_mm_sd(): return { "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), + "positive_conditioning": torch.empty(58, 5120), + "negative_conditioning": torch.empty(64, 5120), } def _make_seedvr2_7b_shared_mm_sd(): return { "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + "positive_conditioning": torch.empty(58, 5120), + "negative_conditioning": torch.empty(64, 5120), } def _make_seedvr2_3b_shared_mm_sd(): return { "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + "positive_conditioning": torch.empty(58, 5120), + "negative_conditioning": torch.empty(64, 5120), } +def _add_model_diffusion_prefix(sd): + return {f"model.diffusion_model.{k}": v for k, v in sd.items()} + + class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -182,6 +192,20 @@ class TestModelDetection: assert unet_config["num_layers"] == 32 assert unet_config["mlp_type"] == "swiglu" + def test_seedvr2_model_match_requires_conditioning_tensors(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert type(model_config_from_unet_config(unet_config, sd)).__name__ == "SeedVR2" + + del sd["positive_conditioning"] + assert model_config_from_unet_config(unet_config, sd) is None + + def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self): + sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd()) + + assert type(model_config_from_unet(sd, "model.diffusion_model.")).__name__ == "SeedVR2" + def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py index dd3121428..966e9465d 100644 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -103,7 +103,7 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa heads=heads, head_dim=head_dim, qk_bias=False, - qk_norm=seedvr_model.CustomRMSNorm, + qk_norm=comfy_ops.disable_weight_init.RMSNorm, qk_norm_eps=1e-6, rope_type=None, rope_dim=head_dim, diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py index feae2211f..06b2f1564 100644 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -26,6 +26,7 @@ import comfy.ldm.seedvr.model # noqa: E402 import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 import comfy.model_management # noqa: E402 +import comfy.ops as comfy_ops # noqa: E402 import comfy.sample # noqa: E402 import comfy.sd as sd_mod # noqa: E402 import nodes as nodes_mod # noqa: E402 @@ -81,6 +82,7 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis txt_in_dim=txt_in_dim, heads=24, mm_layers=3, + operations=comfy_ops.disable_weight_init, ) return flags @@ -140,6 +142,46 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) +def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): + raw_latent = torch.full((1, 16, 1, 4, 5), 2.0) + seen_shapes = [] + + def base_encode(self, x): + seen_shapes.append(tuple(x.shape)) + return raw_latent.to(device=x.device, dtype=x.dtype) + + monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode) + + vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper) + nn.Module.__init__(vae) + vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32)) + + latent = vae.encode(torch.zeros(1, 3, 32, 40)) + + assert type(latent) is torch.Tensor + assert tuple(latent.shape) == (1, 16, 4, 5) + assert seen_shapes == [(1, 3, 1, 32, 40)] + + +def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): + raw_latent = torch.full((1, 16, 1, 4, 5), 3.0) + + def base_encode(self, x): + return raw_latent.to(device=x.device, dtype=x.dtype) + + monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode) + + vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper) + nn.Module.__init__(vae) + vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32)) + + latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40)) + + assert tuple(latent.shape) == (1, 16, 4, 5) + assert tuple(raw.shape) == (1, 16, 1, 4, 5) + assert torch.equal(raw, raw_latent) + + def _make_vae(wrapper): vae = sd_mod.VAE.__new__(sd_mod.VAE) vae.first_stage_model = wrapper @@ -155,6 +197,8 @@ def _make_vae(wrapper): vae.extra_1d_channel = None vae.crop_input = False vae.not_video = False + vae.handles_tiling = isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae.format_encoded = wrapper.comfy_format_encoded vae.patcher = _Patcher() vae.process_input = lambda image: image vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py index ced2fe34f..0d3c97e4a 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -1,6 +1,7 @@ from contextlib import ExitStack from unittest.mock import MagicMock, patch +import pytest import torch import torch.nn as nn @@ -21,8 +22,6 @@ from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): - from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae - class StubVAEModel(torch.nn.Module): def __init__(self): super().__init__() @@ -37,9 +36,9 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): def decode_(self, t_chunk): self.decode_min_sizes.append(self.slicing_latent_min_size) - return VideoAutoencoderKL.slicing_decode(self, t_chunk) + return vae_mod.VideoAutoencoderKL.slicing_decode(self, t_chunk) - def _decode(self, z, memory_state=MemoryState.DISABLED): + def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None): self.memory_states.append(memory_state) b, c, d, h, w = z.shape return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) @@ -68,8 +67,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): def test_zero_temporal_size_preserves_min_size_when_encode_raises(): - from comfy.ldm.seedvr.vae import tiled_vae - class RaisingVAEModel(torch.nn.Module): def __init__(self): super().__init__() @@ -85,8 +82,7 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises(): vae = RaisingVAEModel() x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) - raised = False - try: + with pytest.raises(RuntimeError, match="simulated encode failure"): tiled_vae( x, vae, @@ -96,15 +92,43 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises(): temporal_overlap=0, encode=True, ) - except RuntimeError as exc: - if "simulated encode failure" not in str(exc): - raise - raised = True - assert raised assert vae.slicing_sample_min_size == 4 +def test_tiled_vae_encode_uses_tensor_return_without_indexing(): + class TensorEncodeVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.calls = [] + + def encode(self, t_chunk): + self.calls.append(tuple(t_chunk.shape)) + b, _, _, h, w = t_chunk.shape + return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype) + + vae = TensorEncodeVAEModel() + x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32) + + out = tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + + assert vae.calls == [(2, 3, 1, 64, 64)] + assert tuple(out.shape) == (2, 16, 1, 8, 8) + + # --------------------------------------------------------------------------- # From test_seedvr_vae_tiled_temporal_slicing.py # --------------------------------------------------------------------------- @@ -126,7 +150,7 @@ class _SlicingDecodeVAE(nn.Module): self.decode_min_sizes.append(self.slicing_latent_min_size) return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) - def _decode(self, z, memory_state=MemoryState.DISABLED): + def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None): self.memory_states.append(memory_state) x = z[:, :1].repeat( 1, @@ -205,6 +229,8 @@ def _make_vae(first_stage_model, latent_channels, latent_dim): vae.latent_dim = latent_dim vae.vae_output_dtype = lambda: torch.float32 vae.spacial_compression_decode = lambda: 8 + vae.handles_tiling = isinstance(first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae.format_encoded = None vae.process_input = lambda x: x vae.process_output = lambda x: x vae.throw_exception_if_invalid = lambda: None @@ -240,7 +266,6 @@ def test_4d_seedvr2_latent_routes_to_owned_decode_tiled(): def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): first_stage = MagicMock() - first_stage.comfy_handles_tiling = False first_stage.decode = MagicMock(side_effect=_force_oom) vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) @@ -273,6 +298,8 @@ def _populate_common_vae_attrs_fallback(vae): vae.not_video = False vae.crop_input = False vae.pad_channel_value = None + vae.handles_tiling = isinstance(vae.first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae.format_encoded = None vae.vae_output_dtype = lambda: torch.float32 vae.spacial_compression_encode = lambda: 8 @@ -295,7 +322,6 @@ def _make_seedvr2_vae_fallback(): def _make_non_seedvr2_vae_fallback(): vae = sd_mod.VAE.__new__(sd_mod.VAE) vae.first_stage_model = MagicMock() - vae.first_stage_model.comfy_handles_tiling = False _populate_common_vae_attrs_fallback(vae) return vae From 77d42ed7e9bd59d5edc07a64df13a6f1ef5b2ab5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 1 Jul 2026 22:19:37 -0400 Subject: [PATCH 10/12] Remove SeedVR2ProgressiveSampler. --- comfy/ldm/seedvr/constants.py | 19 +- comfy_extras/nodes_seedvr.py | 514 ------------------ .../comfy_extras_test/test_seedvr2_nodes.py | 2 +- .../test_seedvr_progressive_sampler.py | 95 ---- 4 files changed, 4 insertions(+), 626 deletions(-) delete mode 100644 tests-unit/comfy_test/test_seedvr_progressive_sampler.py diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py index 71b71d4ad..df91c7772 100644 --- a/comfy/ldm/seedvr/constants.py +++ b/comfy/ldm/seedvr/constants.py @@ -8,26 +8,13 @@ Provenance prefixes: ISO / CIE values; cite the standard. """ -# -------------------------------------------------------------------------------------- -# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) -# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) -# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 -# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). -# -------------------------------------------------------------------------------------- -SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) -SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB - -# -------------------------------------------------------------------------------------- -# B. Fork heuristics (SEEDVR2 - this integration) -# -------------------------------------------------------------------------------------- SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry. SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). -SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). # Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. @@ -36,7 +23,7 @@ SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. # -------------------------------------------------------------------------------------- -# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) +# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) # -------------------------------------------------------------------------------------- BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. @@ -56,7 +43,7 @@ BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max freq BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). # -------------------------------------------------------------------------------------- -# D. Published standards (cite the literature) +# Published standards (cite the literature) # -------------------------------------------------------------------------------------- ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 1fb44ac36..bf5b3c15c 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -1,12 +1,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io import torch -import math -import logging import comfy.model_management -import comfy.sample -import comfy.samplers from comfy.ldm.seedvr.color_fix import ( adain_color_transfer, lab_color_transfer, @@ -14,10 +10,7 @@ from comfy.ldm.seedvr.color_fix import ( ) from comfy.ldm.seedvr.constants import ( SEEDVR2_ADAIN_SCALE_MULTIPLIER, - SEEDVR2_CHUNK_FRAMES_PER_GB, - SEEDVR2_CHUNK_GB_MARGIN, SEEDVR2_COLOR_MEM_HEADROOM, - SEEDVR2_COND_CHANNELS, SEEDVR2_DTYPE_BYTES_FLOOR, SEEDVR2_LAB_SCALE_MULTIPLIER, SEEDVR2_LATENT_CHANNELS, @@ -39,40 +32,6 @@ _SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( _ATTR_MISSING = object() -def _seedvr2_vram_seed_frames_per_chunk(free_bytes, t_pixel): - """Predict the largest 4n+1 pixel-frame chunk that fits in free_bytes.""" - free_gb = free_bytes / (1024 ** 3) - predicted = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_gb - SEEDVR2_CHUNK_GB_MARGIN) - # round (not floor) to 4n+1: the fit's central prediction lands on measured n_max - n = round((predicted - 1) / 4) - seed = 4 * int(n) + 1 - seed = max(1, min(seed, t_pixel)) - return seed - - -def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): - """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" - attempts = [frames_per_chunk] - current_chunk_latent = ( - t_latent if t_pixel <= frames_per_chunk - else (frames_per_chunk - 1) // 4 + 1 - ) - current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) - seen = {frames_per_chunk} - - for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): - chunk_latent = max(1, math.ceil(t_latent / target_chunks)) - candidate = 4 * (chunk_latent - 1) + 1 - if candidate in seen: - continue - if candidate >= attempts[-1]: - continue - attempts.append(candidate) - seen.add(candidate) - - return attempts - - def _resolve_seedvr2_diffusion_model(model): """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" inner = getattr(model, "model", _ATTR_MISSING) @@ -473,478 +432,6 @@ class SeedVR2Conditioning(io.ComfyNode): return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) -def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, - t_end: int, channels: int) -> torch.Tensor: - """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" - B, CT, H, W = tensor_4d.shape - if CT % channels != 0: - raise ValueError( - f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " - f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." - ) - T = CT // channels - if not (0 <= t_start < t_end <= T): - raise ValueError( - f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " - f"range for T={T}." - ) - new_T = t_end - t_start - sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() - return sliced.reshape(B, channels * new_T, H, W) - - -def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): - """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" - new_list = [] - for entry in cond_list: - text_cond, options = entry[0], entry[1] - if "condition" not in options: - new_list.append(entry) - continue - new_options = options.copy() - new_options["condition"] = _slice_collapsed_4d_along_t( - new_options["condition"], t_start, t_end, - SEEDVR2_COND_CHANNELS, - ) - new_list.append([text_cond, new_options]) - return new_list - - -def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, - samples_4d: torch.Tensor, - t_start: int, - t_end: int): - """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" - if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: - return _slice_collapsed_4d_along_t( - noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, - ) - return noise_mask - - -def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: - """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" - if len(chunks_4d) == 0: - raise ValueError("_concat_chunks_along_t: empty chunk list.") - fives = [] - for ch in chunks_4d: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " - f"channel dim {CT} not divisible by channels={channels}." - ) - T = CT // channels - fives.append(ch.reshape(B, channels, T, H, W)) - cat = torch.cat(fives, dim=2).contiguous() - B, C, T_total, H, W = cat.shape - return cat.reshape(B, C * T_total, H, W) - - -def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: - """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): - Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` - (dead-band would collapse a tiny transition). Window shape matched to the reference - overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. - """ - if overlap < 1: - raise ValueError( - f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." - ) - if overlap >= 3: - t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) - blend_start = 1.0 / 3.0 - blend_end = 2.0 / 3.0 - u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) - return 0.5 + 0.5 * torch.cos(torch.pi * u) - return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) - - -def _blend_overlap_region(prev_tail_5d: torch.Tensor, - cur_head_5d: torch.Tensor) -> torch.Tensor: - """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" - if prev_tail_5d.shape != cur_head_5d.shape: - raise ValueError( - f"_blend_overlap_region: shape mismatch " - f"prev {tuple(prev_tail_5d.shape)} vs " - f"cur {tuple(cur_head_5d.shape)}." - ) - overlap = int(prev_tail_5d.shape[2]) - w_prev_1d = _hann_blend_weights_1d( - overlap, prev_tail_5d.device, prev_tail_5d.dtype, - ) - # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. - w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) - w_cur = 1.0 - w_prev - return prev_tail_5d * w_prev + cur_head_5d * w_cur - - -def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, - overlap_latent: int) -> torch.Tensor: - """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" - if len(chunk_specs) == 0: - raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") - if overlap_latent < 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: overlap_latent must be " - f">= 0; got {overlap_latent}." - ) - - # Validate channel divisibility once and capture per-chunk T. - chunk_5d = [] - for t_start, t_end, ch in chunk_specs: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk shape " - f"{tuple(ch.shape)} channel dim {CT} not divisible " - f"by channels={channels}." - ) - T = CT // channels - if t_end - t_start != T: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " - f"declared range [{t_start}:{t_end}]." - ) - chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) - - if overlap_latent == 0: - # Fast path: pure concat in the caller-provided chunk order. - return _concat_chunks_along_t( - [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) - for _, _, c in chunk_5d], - channels, - ) - - T_total = max(t_end for _, t_end, _ in chunk_5d) - first_5d = chunk_5d[0][2] - B = first_5d.shape[0] - H = first_5d.shape[3] - W = first_5d.shape[4] - result = torch.empty( - (B, channels, T_total, H, W), - device=first_5d.device, dtype=first_5d.dtype, - ) - filled_until = 0 - for i, (cs, ce, ct_5d) in enumerate(chunk_5d): - chunk_T = int(ct_5d.shape[2]) - if i == 0: - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - continue - # Overlap region width is bounded by both the previous fill - # frontier and the current chunk's actual length (for runt - # final chunks shorter than the configured overlap). - overlap_len = min(filled_until - cs, chunk_T) - if overlap_len > 0: - prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() - cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() - blended = _blend_overlap_region(prev_tail, cur_head) - result[:, :, cs:cs + overlap_len, :, :] = blended - tail_start = cs + overlap_len - tail_end = ce - if tail_end > tail_start: - result[:, :, tail_start:tail_end, :, :] = ( - ct_5d[:, :, overlap_len:, :, :] - ) - else: - # Disjoint chunks (overlap_latent set but this pair did not - # actually overlap, e.g. step_latent equal to chunk_latent - # in a degenerate config). Treat as concat. - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - - return result.contiguous().reshape(B, channels * T_total, H, W) - - -def _run_standard_sample(model, seed: int, steps: int, cfg: float, - sampler_name: str, scheduler: str, - positive, negative, latent: dict, - denoise: float) -> dict: - """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" - samples_in = latent["samples"] - samples_in = comfy.sample.fix_empty_latent_channels( - model, samples_in, latent.get("downscale_ratio_spacial", None), - ) - batch_inds = latent.get("batch_index", None) - noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) - noise_mask = latent.get("noise_mask", None) - samples = comfy.sample.sample( - model, noise, steps, cfg, sampler_name, scheduler, - positive, negative, samples_in, - denoise=denoise, noise_mask=noise_mask, seed=seed, - ) - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = samples - return out - - -class SeedVR2ProgressiveSampler(io.ComfyNode): - """Sequential temporal chunking sampler for SeedVR2 native. - - Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that - OOM on long sequences. The latent enters the sampler in SeedVR2's - collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``reshape(b, c * t, h, w)``); this node slices that - tensor along the temporal axis, runs the configured inner sampler - sequentially per chunk against the standard ``comfy.sample.sample`` - entry point, and concatenates per-chunk outputs back into a single - ``(B, 16*T_total, H, W)`` latent. - - ``frames_per_chunk`` is expressed in pixel-frame units to match the - SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the - VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` - maps to ``(F - 1) // 4 + 1`` latent-frame chunks. - - Determinism contract: a single noise tensor is generated once from - the user seed and sliced per chunk (rather than re-seeding each - chunk), so a workflow that fits in a single chunk produces output - identical to a workflow that fits in N chunks at the same seed, - modulo the inherent T-axis chunk-boundary independence of the model. - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2ProgressiveSampler", - display_name="Sample SeedVR2 (Progressive)", - category="sampling", - description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", - search_aliases=["seedvr2", "upscale", "video upscale", "sampler", "chunk"], - inputs=[ - io.Model.Input("model", tooltip="The model used for denoising the input latent."), - io.Int.Input("seed", default=0, min=0, - max=0xffffffffffffffff, - control_after_generate=True, - tooltip="The random seed used for creating the noise."), - io.Int.Input("steps", default=20, min=1, max=10000, - tooltip="The number of steps used in the denoising process."), - io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, - step=0.1, round=0.01, - tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), - io.Combo.Input("sampler_name", - options=comfy.samplers.SAMPLER_NAMES, - tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), - io.Combo.Input("scheduler", - options=comfy.samplers.SCHEDULER_NAMES, - tooltip="The scheduler controls how noise is gradually removed to form the image."), - io.Conditioning.Input("positive", - tooltip="The conditioning describing the attributes you want to include in the image."), - io.Conditioning.Input("negative", - tooltip="The conditioning describing the attributes you want to exclude from the image."), - io.Latent.Input("latent", - tooltip="The latent image to denoise."), - io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, - step=0.01, - tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), - io.Int.Input("frames_per_chunk", default=21, min=1, - max=16384, step=4, - tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), - io.Int.Input("temporal_overlap", default=0, min=0, - max=16384, - tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), - io.Combo.Input("chunking_mode", - options=["manual", "auto"], - default="manual", - tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), - ], - outputs=[io.Latent.Output(display_name="latent", tooltip="The upscaled latent.")], - ) - - @classmethod - def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - frames_per_chunk, temporal_overlap, - chunking_mode="manual") -> io.NodeOutput: - # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline - # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), - # imposed at ``cut_videos`` upstream and propagated through the VAE's - # temporal_downsample_factor=4. Reject violations explicitly before - # any model invocation; a silent rounding would mis-align chunk - # boundaries with the 4n+1 lattice. - if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " - f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " - f"got {frames_per_chunk}." - ) - - samples_4d = latent["samples"] - if torch.count_nonzero(samples_4d) == 0: - raise ValueError( - "SeedVR2ProgressiveSampler: input latent is empty (all zeros). " - "SeedVR2 is an upscaler; connect an encoded latent from " - "'Apply SeedVR2 conditioning' rather than an empty latent." - ) - samples_4d = comfy.sample.fix_empty_latent_channels( - model, samples_4d, - latent.get("downscale_ratio_spacial", None), - ) - if samples_4d.ndim != 4: - raise ValueError( - f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " - f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." - ) - B, CT, H, W = samples_4d.shape - if CT % SEEDVR2_LATENT_CHANNELS != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " - f"not divisible by SeedVR2 latent channels " - f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " - f"SeedVR2-shaped." - ) - T_latent = CT // SEEDVR2_LATENT_CHANNELS - T_pixel = 4 * (T_latent - 1) + 1 - - if chunking_mode not in ("manual", "auto"): - raise ValueError( - f"SeedVR2ProgressiveSampler: chunking_mode must be " - f"'manual' or 'auto'; got {chunking_mode!r}." - ) - - if chunking_mode == "auto": - free_memory = comfy.model_management.get_free_memory(model.load_device) - seed_frames_per_chunk = _seedvr2_vram_seed_frames_per_chunk( - free_memory, T_pixel, - ) - logging.info( - "SeedVR2ProgressiveSampler auto: free=%.2fGB -> seeding " - "frames_per_chunk=%s (4n+1; T_pixel=%s).", - free_memory / (1024 ** 3), seed_frames_per_chunk, T_pixel, - ) - attempts = _seedvr2_auto_chunk_attempts( - T_latent, T_pixel, seed_frames_per_chunk, - ) - for i, attempt_frames_per_chunk in enumerate(attempts): - retry = False - try: - return cls.execute( - model=model, seed=seed, steps=steps, cfg=cfg, - sampler_name=sampler_name, scheduler=scheduler, - positive=positive, negative=negative, - latent=latent, denoise=denoise, - frames_per_chunk=attempt_frames_per_chunk, - temporal_overlap=temporal_overlap, - chunking_mode="manual", - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if i == len(attempts) - 1: - raise RuntimeError( - "SeedVR2ProgressiveSampler: exhausted auto " - "chunking attempts after OOM. Tried " - f"frames_per_chunk values {attempts}." - ) from e - retry = True - - if retry: - logging.warning( - "SeedVR2ProgressiveSampler auto chunking OOM at " - "frames_per_chunk=%s; retrying with " - "frames_per_chunk=%s.", - attempt_frames_per_chunk, attempts[i + 1], - ) - - # Short-circuit: total fits in one chunk -> standard path with no - # chunking overhead. Output of this branch is byte-identical to the - # built-in KSampler given the same (model, seed, steps, cfg, - # sampler_name, scheduler, positive, negative, latent, - # denoise) tuple. - if T_pixel <= frames_per_chunk: - return io.NodeOutput(_run_standard_sample( - model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - )) - - # Map pixel chunk -> latent chunk. Each chunk's latent length is - # at most ``chunk_latent``; the final chunk may be a runt that - # is automatically 4n+1-aligned in the pixel domain by the - # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer - # T_latent corresponds to a valid 4n+1 pixel count). - chunk_latent = (frames_per_chunk - 1) // 4 + 1 - - # ``temporal_overlap`` is exposed in latent-frame units, but users - # do not know the derived latent chunk length. Treat oversized - # values as "maximum valid overlap" while preserving a strictly - # positive chunk-loop stride. - if temporal_overlap < 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " - f"got {temporal_overlap}." - ) - temporal_overlap = min(temporal_overlap, chunk_latent - 1) - step_latent = chunk_latent - temporal_overlap - - # Generate full noise once from the user seed, then slice along T - # per chunk. Using one global noise tensor (rather than re-seeding - # per chunk) preserves seed-determinism across chunk-count - # variations: the same (seed, total T_latent) always produces the - # same noise samples regardless of how the work is partitioned. - batch_inds = latent.get("batch_index", None) - noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) - - noise_mask = latent.get("noise_mask", None) - - # Build the flat list of chunk ranges first so the chunking - # geometry is fully known before any sample call. - chunk_ranges = [] - for chunk_start in range(0, T_latent, step_latent): - chunk_end = min(chunk_start + chunk_latent, T_latent) - if chunk_start >= chunk_end: - # The final iteration of a stride that lands exactly on - # T_latent produces a zero-length chunk; skip it. - break - chunk_ranges.append((chunk_start, chunk_end)) - if chunk_end >= T_latent: - break - - def _sample_one_chunk(chunk_start, chunk_end): - samples_chunk = _slice_collapsed_4d_along_t( - samples_4d, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - noise_chunk = _slice_collapsed_4d_along_t( - noise_full, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - positive_chunk = _slice_seedvr2_cond_along_t( - positive, chunk_start, chunk_end, - ) - negative_chunk = _slice_seedvr2_cond_along_t( - negative, chunk_start, chunk_end, - ) - - # Per-chunk noise_mask handling: standard masks are passed - # through for KSampler expansion; pre-expanded collapsed - # masks are sliced. - chunk_noise_mask = None - if noise_mask is not None: - chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( - noise_mask, samples_4d, chunk_start, chunk_end, - ) - - return comfy.sample.sample( - model, noise_chunk, steps, cfg, sampler_name, scheduler, - positive_chunk, negative_chunk, samples_chunk, - denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, - ) - - chunk_specs = [] - for chunk_start, chunk_end in chunk_ranges: - chunk_samples = _sample_one_chunk(chunk_start, chunk_end) - chunk_specs.append((chunk_start, chunk_end, chunk_samples)) - - final = _concat_chunks_with_overlap_blend( - chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, - ) - - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = final - return io.NodeOutput(out) - - class SeedVRExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -952,7 +439,6 @@ class SeedVRExtension(ComfyExtension): SeedVR2Conditioning, SeedVR2Preprocess, SeedVR2PostProcessing, - SeedVR2ProgressiveSampler, ] async def comfy_entrypoint() -> SeedVRExtension: diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py index f7d9a4f65..1c5d20ac9 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py @@ -31,7 +31,7 @@ def test_seedvr_node_signature_matches_schema(): sys.modules.pop("comfy_extras.nodes_seedvr", None) try: nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") - for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): + for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning): schema_ids = [i.id for i in node_cls.define_schema().inputs] exec_params = [ p for p in inspect.signature(node_cls.execute).parameters.keys() diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py deleted file mode 100644 index 146b81225..000000000 --- a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sample # noqa: E402 -import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 -from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 - -_LAT_C = 16 -_COND_C = 17 - - -def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): - """Build minimal SeedVR2-shaped sampling inputs.""" - samples_5d = torch.arange( - B * _LAT_C * T * H * W, dtype=torch.float32 - ).reshape(B, _LAT_C, T, H, W) - samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() - - cond_5d = torch.arange( - B * _COND_C * T * H * W, dtype=torch.float32 - ).reshape(B, _COND_C, T, H, W) + 10000.0 - cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() - - text_pos = torch.zeros(1, 4, 32) - text_neg = torch.zeros(1, 4, 32) - positive = [[text_pos, {"condition": cond.clone()}]] - negative = [[text_neg, {"condition": cond.clone()}]] - latent_image = {"samples": samples} - return latent_image, positive, negative, samples_5d, cond_5d - - -def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): - return latent_image - - -def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): - """Return a tensor whose values encode ``(seed, position)``.""" - base = torch.arange( - latent_image.numel(), dtype=torch.float32 - ).reshape(latent_image.shape) - return base + float(seed) * 1e6 - - -def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): - schema = SeedVR2ProgressiveSampler.define_schema() - inputs = {item.id: item for item in schema.inputs} - - assert inputs["chunking_mode"].options == ["manual", "auto"] - assert inputs["chunking_mode"].default == "manual" - - -def test_vram_seed_frames_per_chunk_predicts_4n1_clamped_to_t_pixel(): - """VRAM chunk-size law: seed = nearest 4n+1 to 4*(free_GB - 3), clamped to [1, t_pixel].""" - gib = 1024 ** 3 - seed = nodes_seedvr_mod._seedvr2_vram_seed_frames_per_chunk - assert seed(20 * gib, 65) == 65 # 4*(20-3)=68 -> 4n+1 69 -> clamp to t_pixel 65 - assert seed(6 * gib, 97) == 13 # 4*(6-3)=12 -> nearest 4n+1 13 - assert seed(2 * gib, 97) == 1 # below margin -> floor at 1 - - -@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) -def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): - """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" - latent, pos, neg, _, _ = _make_inputs(T=5) - - sampler_called = {"n": 0} - - def _should_not_be_called(*args, **kwargs): - sampler_called["n"] += 1 - return torch.zeros(1) - - with patch.object(comfy.sample, "sample", - side_effect=_should_not_be_called), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise): - with pytest.raises(ValueError) as excinfo: - SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, - ) - assert str(bad_chunk) in str(excinfo.value) - assert sampler_called["n"] == 0 From c7b2c3b56955fd378a7099f54ca94c6d909bede2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 2 Jul 2026 22:59:38 -0400 Subject: [PATCH 11/12] Refactors and cleanups. --- comfy/latent_formats.py | 1 + comfy/ldm/seedvr/attention.py | 25 +- comfy/ldm/seedvr/color_fix.py | 80 ++--- comfy/ldm/seedvr/constants.py | 47 +-- comfy/ldm/seedvr/model.py | 213 ++++++------- comfy/ldm/seedvr/vae.py | 295 ++++++------------ comfy/model_base.py | 3 +- comfy/model_detection.py | 35 ++- comfy/sd.py | 46 ++- comfy/supported_models.py | 1 + comfy_extras/nodes_seedvr.py | 84 ++--- .../test_seedvr2_conditioning.py | 49 +-- tests-unit/comfy_test/model_detection_test.py | 11 + .../comfy_test/seedvr_vae_forward_test.py | 34 +- tests-unit/comfy_test/test_seedvr2_dtype.py | 5 +- .../comfy_test/test_seedvr2_internals.py | 55 +--- tests-unit/comfy_test/test_seedvr2_model.py | 98 ++---- .../comfy_test/test_seedvr2_vae_decode.py | 11 +- .../comfy_test/test_seedvr2_vae_tiled.py | 72 +++-- 19 files changed, 436 insertions(+), 729 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fc5b13c21..8a16cfe55 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -781,6 +781,7 @@ class ACEAudio(LatentFormat): class SeedVR2(LatentFormat): latent_channels = 16 + latent_dimensions = 3 class ACEAudio15(LatentFormat): latent_channels = 64 diff --git a/comfy/ldm/seedvr/attention.py b/comfy/ldm/seedvr/attention.py index 5d4054ab9..11b4c1e4a 100644 --- a/comfy/ldm/seedvr/attention.py +++ b/comfy/ldm/seedvr/attention.py @@ -22,33 +22,14 @@ def _var_attention_output(out, heads, head_dim, skip_output_reshape): return out.reshape(-1, heads * head_dim) -def _validate_split_cu_seqlens(name, cu_seqlens, token_count): - if cu_seqlens.dtype not in (torch.int32, torch.int64): - raise ValueError(f"{name} must use an integer dtype") - if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: - raise ValueError(f"{name} must be a 1D tensor with at least two offsets") - if cu_seqlens[0].item() != 0: - raise ValueError(f"{name} must start at 0") - if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): - raise ValueError(f"{name} must be strictly increasing") - if cu_seqlens[-1].item() != token_count: - raise ValueError(f"{name} does not match token count") - - -def _split_indices(cu_seqlens): - return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) - - def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) - _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) - if cu_seqlens_k[-1].item() != v.shape[0]: + q_split_indices = cu_seqlens_q[1:-1] + k_split_indices = cu_seqlens_k[1:-1] + if k.shape[0] != v.shape[0]: raise ValueError("cu_seqlens_k does not match v token count") - q_split_indices = _split_indices(cu_seqlens_q) - k_split_indices = _split_indices(cu_seqlens_k) q_splits = torch.tensor_split(q, q_split_indices, dim=0) k_splits = torch.tensor_split(k, k_split_indices, dim=0) v_splits = torch.tensor_split(v, k_split_indices, dim=0) diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py index 440b3d26c..a43cb5270 100644 --- a/comfy/ldm/seedvr/color_fix.py +++ b/comfy/ldm/seedvr/color_fix.py @@ -45,7 +45,6 @@ def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: if content_feat.shape != style_feat.shape: - # Resize style to match content spatial dimensions if len(content_feat.shape) >= 3: style_feat = F.interpolate( style_feat, @@ -54,12 +53,11 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: align_corners=False ) - # Decompose both features into frequency components content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq # Free memory immediately + del content_low_freq style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq # Free memory immediately + del style_high_freq if content_high_freq.shape != style_low_freq.shape: style_low_freq = F.interpolate( @@ -73,27 +71,23 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: return content_high_freq.clamp_(-1.0, 1.0) -def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: +def _histogram_matching_channel(source: Tensor, reference: Tensor) -> Tensor: original_shape = source.shape - # Flatten source_flat = source.flatten() reference_flat = reference.flatten() - # Sort both arrays source_sorted, source_indices = torch.sort(source_flat) reference_sorted, _ = torch.sort(reference_flat) del reference_flat - # Quantile mapping n_source = len(source_sorted) n_reference = len(reference_sorted) if n_source == n_reference: matched_sorted = reference_sorted else: - # Interpolate reference to match source quantiles - source_quantiles = torch.linspace(0, 1, n_source, device=device) + source_quantiles = torch.linspace(0, 1, n_source, device=source.device) ref_indices = (source_quantiles * (n_reference - 1)).long() ref_indices.clamp_(0, n_reference - 1) matched_sorted = reference_sorted[ref_indices] @@ -101,7 +95,6 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch del source_sorted, source_flat - # Reconstruct using argsort (portable across CUDA/ROCm/MPS) inverse_indices = torch.argsort(source_indices) del source_indices matched_flat = matched_sorted[inverse_indices] @@ -109,17 +102,14 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch return matched_flat.reshape(original_shape) -def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of CIELAB images to RGB color space.""" +def _lab_to_rgb_batch(lab: Tensor, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] - # LAB to XYZ fy = (L + 16.0) / 116.0 fx = a.div(500.0).add_(fy) fz = fy - b / 200.0 del L, a, b - # XYZ transformation x = torch.where( fx > epsilon, torch.pow(fx, 3.0), @@ -137,20 +127,16 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps ) del fx, fy, fz - # Apply D65 white point (in-place) x.mul_(D65_WHITE_X) - # y *= 1.00000 # (no-op, skip) z.mul_(D65_WHITE_Z) xyz = torch.stack([x, y, z], dim=1) del x, y, z - # Matrix multiplication: XYZ -> RGB - B, C, H, W = xyz.shape + B, _, H, W = xyz.shape xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) del xyz - # Ensure dtype consistency for matrix multiplication xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) del xyz_flat @@ -158,7 +144,6 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) del rgb_linear_flat - # Apply inverse gamma correction (delinearize) mask = rgb_linear > 0.0031308 rgb = torch.where( mask, @@ -169,9 +154,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps return torch.clamp(rgb, 0.0, 1.0) -def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" - # Apply sRGB gamma correction (linearize) +def _rgb_to_lab_batch(rgb: Tensor, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: mask = rgb > 0.04045 rgb_linear = torch.where( mask, @@ -180,12 +163,10 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon ) del mask - # Matrix multiplication: RGB -> XYZ - B, C, H, W = rgb_linear.shape + B, _, H, W = rgb_linear.shape rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) del rgb_linear - # Ensure dtype consistency for matrix multiplication rgb_flat = rgb_flat.to(dtype=matrix.dtype) xyz_flat = torch.matmul(rgb_flat, matrix.T) del rgb_flat @@ -193,12 +174,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) del xyz_flat - # Normalize by D65 white point (in-place) - xyz[:, 0].div_(D65_WHITE_X) # X - # xyz[:, 1] /= 1.00000 # Y (no-op, skip) - xyz[:, 2].div_(D65_WHITE_Z) # Z + xyz[:, 0].div_(D65_WHITE_X) + xyz[:, 2].div_(D65_WHITE_Z) - # XYZ to LAB transformation epsilon_cubed = epsilon ** 3 mask = xyz > epsilon_cubed f_xyz = torch.where( @@ -208,10 +186,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon ) del xyz, mask - # Extract channels and compute LAB - L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] - a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] - b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] + L = f_xyz[:, 1].mul(116.0).sub_(16.0) + a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) + b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) del f_xyz return torch.stack([L, a, b], dim=1) @@ -232,13 +209,9 @@ def lab_color_transfer( ) device = content_feat.device - - def ensure_float32_precision(c): - orig_dtype = c.dtype - c = c.float() - return c, orig_dtype - content_feat, original_dtype = ensure_float32_precision(content_feat) - style_feat, _ = ensure_float32_precision(style_feat) + original_dtype = content_feat.dtype + content_feat = content_feat.float() + style_feat = style_feat.float() rgb_to_xyz_matrix = torch.tensor([ [0.4124564, 0.3575761, 0.1804375], @@ -258,39 +231,30 @@ def lab_color_transfer( content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - # Convert to LAB color space - content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + content_lab = _rgb_to_lab_batch(content_feat, rgb_to_xyz_matrix, epsilon, kappa) del content_feat - style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + style_lab = _rgb_to_lab_batch(style_feat, rgb_to_xyz_matrix, epsilon, kappa) del style_feat, rgb_to_xyz_matrix - # Match chrominance channels (a*, b*) for accurate color transfer - matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) - matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) + matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1]) + matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2]) - # Handle luminance with weighted blending if luminance_weight < 1.0: - # Partially match luminance for better overall color accuracy - matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) - # Blend: preserve some content L* for detail, adopt some style L* for color + matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0]) result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) del matched_L else: - # Fully preserve content luminance result_L = content_lab[:, 0] del content_lab, style_lab - # Reconstruct LAB with corrected channels result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) del result_L, matched_a, matched_b - # Convert back to RGB - result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) + result_rgb = _lab_to_rgb_batch(result_lab, xyz_to_rgb_matrix, epsilon, kappa) del result_lab, xyz_to_rgb_matrix - # Convert back to [-1, 1] range (in-place) result = result_rgb.mul_(2.0).sub_(1.0) del result_rgb diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py index df91c7772..b8b300388 100644 --- a/comfy/ldm/seedvr/constants.py +++ b/comfy/ldm/seedvr/constants.py @@ -1,34 +1,21 @@ -"""Named constants for the SeedVR2 integration, grouped by provenance. +"""SeedVR2 constants.""" -Provenance prefixes: -- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. -- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites - the upstream config/source path it was lifted from. -- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / - ISO / CIE values; cite the standard. -""" - -SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. - # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry. -SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). -SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. +SEEDVR2_7B_VID_DIM = 3072 +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 +SEEDVR2_DTYPE_BYTES_FLOOR = 4 +SEEDVR2_7B_MLP_CHUNK = 8192 SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. -SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). +SEEDVR2_LATENT_CHANNELS = 16 -# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) -SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. -SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. +SEEDVR2_COLOR_MEM_HEADROOM = 0.75 +SEEDVR2_LAB_SCALE_MULTIPLIER = 13 SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. -SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. +SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 -# -------------------------------------------------------------------------------------- -# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) -# -------------------------------------------------------------------------------------- -BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. -BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. -BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). -BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). +BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57. +BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 +BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 +BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). @@ -42,18 +29,10 @@ BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal wind BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). -# -------------------------------------------------------------------------------------- -# Published standards (cite the literature) -# -------------------------------------------------------------------------------------- ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. -# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). - -# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and -# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the -# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index ee50449a4..872140558 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, Callable import torch.nn.functional as F from math import ceil, pi import torch -from itertools import chain +from itertools import accumulate, chain from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding from comfy.ldm.seedvr.attention import optimized_var_attention from torch.nn.modules.utils import _triple @@ -18,6 +18,7 @@ from comfy.ldm.seedvr.constants import ( ROPE_THETA, SEEDVR2_7B_MLP_CHUNK, SEEDVR2_7B_VID_DIM, + SEEDVR2_LATENT_CHANNELS, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, ) import comfy.model_management @@ -70,7 +71,7 @@ def repeat_concat_idx( vid_idx = torch.arange(vid_len.sum(), device=device) txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list) src_idx = torch.argsort(tgt_idx) txt_idx_len = len(tgt_idx) - len(vid_idx) repeat_txt_len = (txt_len * txt_repeat).tolist() @@ -88,6 +89,9 @@ def repeat_concat_idx( lambda all: unconcat_coalesce(all), ) +def cumulative_lengths(lengths): + return [0, *accumulate(lengths)] + @dataclass class MMArg: @@ -110,16 +114,14 @@ def get_window_op(name: str): raise ValueError(f"Unknown windowing method: {name}") -# -------------------------------- Windowing -------------------------------- # def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): t, h, w = size resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) + nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) return [ ( slice(it * wt, min((it + 1) * wt, t)), @@ -137,19 +139,18 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): t, h, w = size resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) - st, sh, sw = ( # shift size. + st, sh, sw = ( 0.5 if wt < t else 0, 0.5 if wh < h else 0, 0.5 if ww < w else 0, ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) + nt, nh, nw = ( nt + 1 if st > 0 else 1, nh + 1 if sh > 0 else 1, nw + 1 if sw > 0 else 1, @@ -175,7 +176,6 @@ class RotaryEmbedding(nn.Module): freqs_for = 'lang', theta = 10000, max_freq = 10, - learned_freq = False, ): super().__init__() @@ -185,18 +185,14 @@ class RotaryEmbedding(nn.Module): freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + else: + raise ValueError(f"Unknown rotary frequency type: {freqs_for}") - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) - - self.learned_freq = learned_freq - - # dummy for device - - self.register_buffer('dummy', torch.tensor(0), persistent = False) + self.register_buffer("freqs", freqs) @property def device(self): - return self.dummy.device + return self.freqs.device def get_axial_freqs( self, @@ -206,10 +202,9 @@ class RotaryEmbedding(nn.Module): Colon = slice(None) all_freqs = [] - # handle offset - if exists(offsets): - assert len(offsets) == len(dims) + if len(offsets) != len(dims): + raise ValueError(f"SeedVR2 rotary offsets length must match dims length, got {len(offsets)} and {len(dims)}.") for ind, dim in enumerate(dims): @@ -224,7 +219,7 @@ class RotaryEmbedding(nn.Module): pos = pos + offset - freqs = self.forward(pos, seq_len = dim) + freqs = self.forward(pos) all_axis = [None] * len(dims) all_axis[ind] = Colon @@ -232,16 +227,12 @@ class RotaryEmbedding(nn.Module): new_axis_slice = (Ellipsis, *all_axis, Colon) all_freqs.append(freqs[new_axis_slice]) - # concat all freqs - all_freqs = torch.broadcast_tensors(*all_freqs) return torch.cat(all_freqs, dim = -1) def forward( self, t, - seq_len: int | None = None, - offset = 0 ): freqs = self.freqs @@ -258,9 +249,6 @@ class RotaryEmbeddingBase(nn.Module): freqs_for="pixel", max_freq=BYTEDANCE_ROPE_MAX_FREQ, ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.detach()) def get_axial_freqs(self, *dims): return self.rope.get_axial_freqs(*dims) @@ -306,7 +294,7 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d): freqs_for="pixel", max_freq=BYTEDANCE_ROPE_MAX_FREQ, ) - plain_rope = plain_rope.to(self.rope.dummy.device) + plain_rope = plain_rope.to(self.rope.device) freq_list = [] for f, h, w in shape.tolist(): freqs = plain_rope.get_axial_freqs(f, h, w) @@ -322,9 +310,6 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase): freqs_for="lang", theta=ROPE_THETA, ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.detach()) self.mm = True def slice_at_dim(t, dim_slice: slice, *, dim): @@ -333,8 +318,6 @@ def slice_at_dim(t, dim_slice: slice, *, dim): colons[dim] = dim_slice return t[tuple(colons)] -# rotary embedding helper functions - def rotate_half(x): x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2) x1, x2 = x.unbind(dim = -1) @@ -373,7 +356,6 @@ def _apply_seedvr2_rotary_emb( return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: - """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" angles = freqs_interleaved[..., ::2].float() cos = torch.cos(angles) sin = torch.sin(angles) @@ -382,12 +364,6 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest - through; in-place for inference, cloned for training (autograd). Mirrors the legacy - ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives - ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when - ``rot_d == t.shape[-1]``. - """ out = t.clone() if t.requires_grad or comfy.model_management.in_training else t rot_d = 2 * freqs_cis.shape[-3] seq_len = out.shape[-2] @@ -454,14 +430,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): torch.Tensor, ]: - # Calculate actual max dimensions needed for this batch max_temporal = 0 max_height = 0 max_width = 0 max_txt_len = 0 for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal + max_temporal = max(max_temporal, l + f) max_height = max(max_height, h) max_width = max(max_width, w) max_txt_len = max(max_txt_len, l) @@ -475,7 +450,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): ).float() txt_freqs = self.get_axial_freqs(max_txt_len + 16) - # Now slice as before vid_freq_list, txt_freq_list = [], [] for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) @@ -485,13 +459,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) - # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` - # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the - # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` - # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. - # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing - # 2x2 is the per-frequency rotation matrix that - # `comfy.ldm.flux.math.apply_rope1` expects. return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) class MMModule(nn.Module): @@ -507,8 +474,10 @@ class MMModule(nn.Module): self.shared_weights = shared_weights self.vid_only = vid_only if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + if get_args("vid", args) != get_args("txt", args): + raise ValueError("SeedVR2 shared MMModule requires matching vid/txt args.") + if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs): + raise ValueError("SeedVR2 shared MMModule requires matching vid/txt kwargs.") self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) else: self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) @@ -543,6 +512,7 @@ def get_na_rope(rope_type: Optional[str], dim: int): return NaRotaryEmbedding3d(dim=dim) if rope_type == "mmrope3d": return NaMMRotaryEmbedding3d(dim=dim) + raise ValueError(f"Unknown SeedVR2 rope type: {rope_type}") class NaMMAttention(nn.Module): def __init__( @@ -558,7 +528,6 @@ class NaMMAttention(nn.Module): rope_dim: int, shared_weights: bool, device, dtype, operations, - **kwargs, ): super().__init__() dim = MMArg(vid_dim, txt_dim) @@ -597,16 +566,19 @@ def window( ): hid = unflatten(hid, hid_shape) hid = list(map(window_fn, hid)) - hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows + hid_windows_list = [len(x) for x in hid] + hid_windows = torch.as_tensor(hid_windows_list, device=hid_shape.device) + hid = list(chain(*hid)) + hid_len_list = [math.prod(x.shape[:-1]) for x in hid] + hid, hid_shape = flatten(hid) + return hid, hid_shape, hid_windows, hid_len_list, hid_windows_list def window_idx( hid_shape: torch.LongTensor, # (b n) window_fn: Callable[[torch.Tensor], List[torch.Tensor]], ): hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx, tgt_shape, tgt_windows, tgt_len_list, tgt_windows_list = window(hid_idx, hid_shape, window_fn) tgt_idx = tgt_idx.squeeze(-1) src_idx = torch.argsort(tgt_idx) return ( @@ -614,6 +586,8 @@ def window_idx( lambda hid: torch.index_select(hid, 0, src_idx), tgt_shape, tgt_windows, + tgt_len_list, + tgt_windows_list, ) class NaSwinAttention(NaMMAttention): @@ -622,13 +596,15 @@ class NaSwinAttention(NaMMAttention): *args, window: Union[int, Tuple[int, int, int]], window_method: str, + version: bool = False, **kwargs, ): super().__init__(*args, **kwargs) - self.version_7b = kwargs.get("version", False) + self.version_7b = version self.window = _triple(window) self.window_method = window_method - assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + if not all(isinstance(v, int) and v >= 0 for v in self.window): + raise ValueError(f"SeedVR2 window must contain non-negative integers, got {self.window}.") self.window_op = get_window_op(window_method) @@ -646,7 +622,6 @@ class NaSwinAttention(NaMMAttention): vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - # re-org the input seq for window attn cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") def make_window(x: torch.Tensor): @@ -654,7 +629,7 @@ class NaSwinAttention(NaMMAttention): window_slices = self.window_op((t, h, w), self.window) return [x[st, sh, sw] for (st, sh, sw) in window_slices] - window_partition, window_reverse, window_shape, window_count = cache_win( + window_partition, window_reverse, window_shape, window_count, vid_len_win_list, window_count_list = cache_win( "win_transform", lambda: window_idx(vid_shape, make_window), ) @@ -674,23 +649,21 @@ class NaSwinAttention(NaMMAttention): vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) txt_len = txt_len.to(window_count.device) - # window rope if self.rope: if self.version_7b: vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) elif self.rope.mm: - # repeat text q and k for window mmrope _, num_h, _ = txt_q.shape txt_q_repeat = txt_q.flatten(1, 2) txt_q_repeat = unflatten(txt_q_repeat, txt_shape) - txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count_list)] txt_q_repeat = list(chain(*txt_q_repeat)) txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim) txt_k_repeat = txt_k.flatten(1, 2) txt_k_repeat = unflatten(txt_k_repeat, txt_shape) - txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count_list)] txt_k_repeat = list(chain(*txt_k_repeat)) txt_k_repeat, _ = flatten(txt_k_repeat) txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim) @@ -702,7 +675,11 @@ class NaSwinAttention(NaMMAttention): vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + txt_len_win_list = cache_win( + "txt_len_list", + lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)], + ) + all_len_win = cache_win("all_len", lambda: [vid_len + txt_len for vid_len, txt_len in zip(vid_len_win_list, txt_len_win_list)]) concat_win, unconcat_win = cache_win( "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) ) @@ -711,12 +688,8 @@ class NaSwinAttention(NaMMAttention): k=concat_win(vid_k, txt_k), v=concat_win(vid_v, txt_v), heads=self.heads, skip_reshape=True, skip_output_reshape=True, - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), + cu_seqlens_q=cache_win("vid_seqlens_q", lambda: cumulative_lengths(all_len_win)), + cu_seqlens_k=cache_win("vid_seqlens_k", lambda: cumulative_lengths(all_len_win)), ) vid_out, txt_out = unconcat_win(out) @@ -766,11 +739,11 @@ class SwiGLUMLP(nn.Module): return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) def get_mlp(mlp_type: Optional[str] = "normal"): - # 3b and 7b uses different mlp types if mlp_type == "normal": return MLP - elif mlp_type == "swiglu": + if mlp_type == "swiglu": return SwiGLUMLP + raise ValueError(f"Unknown SeedVR2 MLP type: {mlp_type}") class NaMMSRTransformerBlock(nn.Module): def __init__( @@ -792,11 +765,12 @@ class NaMMSRTransformerBlock(nn.Module): rope_type: str, rope_dim: int, is_last_layer: bool, + window: Union[int, Tuple[int, int, int]], + window_method: str, + version: bool, device, dtype, operations, - **kwargs, ): super().__init__() - version = kwargs.get("version", False) dim = MMArg(vid_dim, txt_dim) self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) @@ -811,8 +785,8 @@ class NaMMSRTransformerBlock(nn.Module): rope_type=rope_type, rope_dim=rope_dim, shared_weights=shared_weights, - window=kwargs.pop("window", None), - window_method=kwargs.pop("window_method", None), + window=window, + window_method=window_method, version=version, device=device, dtype=dtype, operations=operations ) @@ -930,12 +904,14 @@ class NaPatchOut(PatchOut): self, vid: torch.FloatTensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), + cache: Optional[Cache] = None, vid_shape_before_patchify = None ) -> Tuple[ torch.FloatTensor, torch.LongTensor, ]: + if cache is None: + cache = Cache(disable=True) t, h, w = self.patch_size vid = self.proj(vid) @@ -971,7 +947,10 @@ class PatchIn(nn.Module): ) -> torch.Tensor: t, h, w = self.patch_size if t > 1: - assert vid.size(2) % t == 1 + if vid.size(2) % t != 1: + raise ValueError( + f"SeedVR2 patch input temporal size must satisfy T % {t} == 1, got {vid.size(2)}." + ) vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) b, c, Tt, Hh, Ww = vid.shape vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c) @@ -983,8 +962,10 @@ class NaPatchIn(PatchIn): self, vid: torch.Tensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), + cache: Optional[Cache] = None, ) -> torch.Tensor: + if cache is None: + cache = Cache(disable=True) cache = cache.namespace("patch") vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) t, h, w = self.patch_size @@ -1012,10 +993,11 @@ class AdaSingle(nn.Module): dim: int, emb_dim: int, layers: List[str], - modes: List[str] = ["in", "out"], + modes: Tuple[str, ...] = ("in", "out"), device = None, dtype = None, ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + if emb_dim != 6 * dim: + raise ValueError(f"SeedVR2 AdaSingle requires emb_dim == 6 * dim, got emb_dim={emb_dim}, dim={dim}.") super().__init__() self.dim = dim self.emb_dim = emb_dim @@ -1036,22 +1018,20 @@ class AdaSingle(nn.Module): emb: torch.FloatTensor, # b d layer: str, mode: str, - cache: Cache = Cache(disable=True), + cache: Optional[Cache] = None, branch_tag: str = "", hid_len: Optional[torch.LongTensor] = None, # b ) -> torch.FloatTensor: + if cache is None: + cache = Cache(disable=True) idx = self.layers.index(layer) emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :] emb = expand_dims(emb, 1, hid.ndim + 1) if hid_len is not None: - slice_inputs = lambda x, dim: x emb = cache( f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), - dim=0, - ), + lambda: torch.repeat_interleave(emb, hid_len, dim=0), ) shiftA, scaleA, gateA = emb.unbind(-1) @@ -1069,7 +1049,7 @@ class AdaSingle(nn.Module): else: return hid.mul_(gateA) - raise NotImplementedError + raise ValueError(f"Unknown AdaSingle mode: {mode}") class TimeEmbedding(nn.Module): @@ -1117,7 +1097,8 @@ def flatten( torch.FloatTensor, # (L c) torch.LongTensor, # (b n) ]: - assert len(hid) > 0 + if len(hid) == 0: + raise ValueError("SeedVR2 flatten requires at least one tensor.") shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device) hid = torch.cat([x.flatten(0, -2) for x in hid]) return hid, shape @@ -1140,7 +1121,7 @@ class NaDiT(nn.Module): num_layers, mlp_type, vid_in_channels = 33, - vid_out_channels = 16, + vid_out_channels = SEEDVR2_LATENT_CHANNELS, vid_dim = 2560, txt_in_dim = 5120, heads = 20, @@ -1148,15 +1129,17 @@ class NaDiT(nn.Module): mm_layers = 10, expand_ratio = 4, qk_bias = False, - patch_size = [ 1,2,2 ], + patch_size = (1, 2, 2), rope_dim = 128, rope_type = "mmrope3d", vid_out_norm: Optional[str] = None, + image_model = None, device = None, dtype = None, operations = None, - **kwargs, ): + if image_model not in (None, "seedvr2"): + raise ValueError(f"SeedVR2 NaDiT expected image_model='seedvr2', got {image_model!r}.") self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM if self._7b_version: rope_type = "rope3d" @@ -1212,14 +1195,13 @@ class NaDiT(nn.Module): rope_dim = rope_dim, window=window[i], window_method=window_method[i], + version = self._7b_version, is_last_layer=(i == num_layers - 1) and not self._7b_version, rope_type = rope_type, shared_weights=not ( (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] ), - version = self._7b_version, operations = operations, - **kwargs, **factory_kwargs ) for i in range(num_layers) @@ -1272,13 +1254,17 @@ class NaDiT(nn.Module): first = cond_or_uncond[0] return all(entry == first for entry in cond_or_uncond) + @staticmethod + def _check_seedvr2_video_latent(x, channels, name): + if x.ndim != 5: + raise ValueError(f"SeedVR2 expected {name} to be 5-D native latent, got shape {tuple(x.shape)}.") + if x.shape[1] != channels: + raise ValueError(f"SeedVR2 expected {name} channels to be {channels}, got shape {tuple(x.shape)}.") + return x + def _swap_pos_neg_halves(self, out, cond_or_uncond=None): if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): return out - # ``dim=0`` is explicit on both calls. The contract is "split - # the batch axis into two halves and swap them"; making the - # axis load-bearing in source guards against silent drift if a - # future refactor reorders tensor axes. pos, neg = out.chunk(2, dim=0) return torch.cat([neg, pos], dim=0) @@ -1294,9 +1280,15 @@ class NaDiT(nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) conditions = kwargs.get("condition") - b, tc, h, w = x.shape - x = x.view(b, 16, -1, h, w) - conditions = conditions.view(b, 17, -1, h, w) + if conditions is None: + raise ValueError("SeedVR2 requires conditioning latents from the SeedVR2Conditioning node.") + x = self._check_seedvr2_video_latent(x, SEEDVR2_LATENT_CHANNELS, "latent") + conditions = self._check_seedvr2_video_latent(conditions, SEEDVR2_LATENT_CHANNELS + 1, "conditioning") + b, _, t, h, w = x.shape + if conditions.shape[0] != b or conditions.shape[2:] != (t, h, w): + raise ValueError( + f"SeedVR2 conditioning shape must match latent batch/temporal/spatial dimensions; got latent {tuple(x.shape)} and conditioning {tuple(conditions.shape)}." + ) x = x.movedim(1, -1) conditions = conditions.movedim(1, -1) cache = Cache(disable=disable_cache) @@ -1361,7 +1353,6 @@ class NaDiT(nn.Module): vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid = unflatten(vid, vid_shape) - out = torch.stack(vid) + out = torch.stack(vid) out = out.movedim(-1, 1) - out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4]) return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 5daab022a..3f23a4691 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -62,7 +62,6 @@ def tiled_vae( temporal_size=16, temporal_overlap=0, encode=True, - **kwargs, ): if x.ndim != 5: x = x.unsqueeze(2) @@ -166,8 +165,8 @@ def tiled_vae( if single_spatial_tile: result = tile_out[:, :, :target_d, :target_h, :target_w] - if result.device != x.device: - result = result.to(x.device).to(x.dtype) + if result.device != x.device or result.dtype != x.dtype: + result = result.to(device=x.device, dtype=x.dtype) if x.shape[2] == 1 and sf_t == 1: result = result.squeeze(2) bar.update(1) @@ -221,8 +220,8 @@ def tiled_vae( result.div_(count.clamp(min=1e-6)) - if result.device != x.device: - result = result.to(x.device).to(x.dtype) + if result.device != x.device or result.dtype != x.dtype: + result = result.to(device=x.device, dtype=x.dtype) if x.shape[2] == 1 and sf_t == 1: result = result.squeeze(2) @@ -256,15 +255,18 @@ class MemoryState(Enum): UNSET = 3 def get_cache_size(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1 remain_len = ( - input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size) ) - overlap_len = dilated_kernerl_size - conv_module.stride[dim] - cache_len = overlap_len + remain_len # >= 0 + overlap_len = dilated_kernel_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len - assert output_len > 0 + if output_len <= 0: + raise ValueError( + f"SeedVR2 VAE cache input is too short for convolution: input_len={input_len}, pad_len={pad_len}." + ) return cache_len class DiagonalGaussianDistribution(object): @@ -294,52 +296,27 @@ class SpatialNorm(nn.Module): new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f -# partial implementation of diffusers's Attention for comfyui class Attention(nn.Module): def __init__( self, query_dim: int, - cross_attention_dim: Optional[int] = None, heads: int = 8, - kv_heads: Optional[int] = None, dim_head: int = 64, - dropout: float = 0.0, bias: bool = False, - upcast_softmax: bool = False, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, - out_dim: int = None, - pre_only=False, ): super().__init__() - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_softmax = upcast_softmax + self.inner_dim = dim_head * heads self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.pre_only = pre_only - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - self.sliceable_head_dim = heads - - self.only_cross_attention = only_cross_attention + self.out_dim = query_dim + self.heads = heads if norm_num_groups is not None: self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) @@ -351,37 +328,19 @@ class Attention(nn.Module): else: self.spatial_norm = None - self.norm_q = None - self.norm_k = None - - self.norm_cross = None self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None + self.to_k = ops.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = ops.Linear(query_dim, self.inner_dim, bias=bias) + self.to_out = nn.ModuleList([]) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Identity()) self.optimized_vae_attention = vae_attention() - def __call__( + def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, - *args, - **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -394,20 +353,14 @@ class Attention(nn.Module): batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + batch_size = hidden_states.shape[0] if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // self.heads @@ -417,25 +370,18 @@ class Attention(nn.Module): key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - if self.norm_q is not None: - query = self.norm_q(query) - if self.norm_k is not None: - key = self.norm_k(key) - - if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: + if input_ndim == 4 and self.heads == 1: query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) else: - hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) + hidden_states = optimized_attention(query, key, value, heads = self.heads, skip_reshape=True, skip_output_reshape=True) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # linear proj hidden_states = self.to_out[0](hidden_states) - # dropout hidden_states = self.to_out[1](hidden_states) if input_ndim == 4: @@ -471,7 +417,10 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: memory_occupy = x.numel() * x.element_size() / 1024**3 if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) - assert norm_layer.num_groups % num_chunks == 0 + if norm_layer.num_groups % num_chunks != 0: + raise ValueError( + f"SeedVR2 VAE GroupNorm groups must divide chunks: groups={norm_layer.num_groups}, chunks={num_chunks}." + ) num_groups_per_chunk = norm_layer.num_groups // num_chunks x = list(x.chunk(num_chunks, dim=1)) @@ -485,14 +434,15 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = norm_layer(x) x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2) return x.to(input_dtype) - raise NotImplementedError + raise TypeError(f"SeedVR2 VAE unsupported norm layer type: {type(norm_layer).__name__}") _receptive_field_t = Literal["half", "full"] def extend_head(tensor, times: int = 2, memory = None): if memory is not None: return torch.cat((memory.to(tensor), tensor), dim=2) - assert times >= 0, "Invalid input for function 'extend_head'!" + if times < 0: + raise ValueError(f"SeedVR2 VAE extend_head expected times >= 0, got {times}.") if times == 0: return tensor else: @@ -547,13 +497,11 @@ class InflatedCausalConv3d(ops.Conv3d): padding=(0, 0, 0, 0, 0, 0), prev_cache=None, ): - # Compatible with no limit. if math.isinf(self.memory_limit): if prev_cache is not None: x = torch.cat([prev_cache, x], dim=split_dim - 1) return super().forward(x) - # Compute tensor shape after concat & padding. shape = list(x.size()) if prev_cache is not None: shape[split_dim - 1] += prev_cache.size(split_dim - 1) @@ -597,16 +545,19 @@ class InflatedCausalConv3d(ops.Conv3d): next_cache = None cache_len = cache.size(split_dim) if cache is not None else 0 - next_catch_size = get_cache_size( + next_cache_size = get_cache_size( conv_module=self, input_len=x[idx].size(split_dim) + cache_len, pad_len=pad_len, dim=split_dim - 2, ) - if next_catch_size != 0: - assert next_catch_size <= x[idx].size(split_dim) + if next_cache_size != 0: + if next_cache_size > x[idx].size(split_dim): + raise ValueError( + f"SeedVR2 VAE cache size {next_cache_size} exceeds split size {x[idx].size(split_dim)}." + ) next_cache = ( - x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) + x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim) ) x[idx] = self.memory_limit_conv( @@ -627,7 +578,8 @@ class InflatedCausalConv3d(ops.Conv3d): memory_state: MemoryState = MemoryState.UNSET, memory_cache = None, ) -> Tensor: - assert memory_state != MemoryState.UNSET + if memory_state == MemoryState.UNSET: + raise ValueError("SeedVR2 VAE convolution requires an explicit MemoryState.") if memory_cache is None: memory_cache = {} if memory_state != MemoryState.ACTIVE: @@ -677,9 +629,8 @@ class InflatedCausalConv3d(ops.Conv3d): input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2 ) - # Single GPU inference - simplified memory management if ( - memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing + memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] and cache_size != 0 ): if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: @@ -690,7 +641,6 @@ class InflatedCausalConv3d(ops.Conv3d): padding = tuple(x for x in reversed(self.padding) for _ in range(2)) for i in range(len(input)): - # Prepare cache for next input slice. next_cache = None cache_size = 0 if i < len(input) - 1: @@ -700,17 +650,16 @@ class InflatedCausalConv3d(ops.Conv3d): if cache_size > input[i].size(2) and cache is not None: input[i] = torch.cat([cache, input[i]], dim=2) cache = None - assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" + if cache_size > input[i].size(2): + raise ValueError(f"SeedVR2 VAE cache size {cache_size} exceeds input length {input[i].size(2)}.") next_cache = input[i][:, :, -cache_size:] - # Conv forward for this input slice. input[i] = self.memory_limit_conv( input[i], padding=padding, prev_cache=cache ) - # Update cache. cache = next_cache return input[0] if squeeze_out else input @@ -729,7 +678,6 @@ class Upsample3D(nn.Module): inflation_mode = "tail", temporal_up: bool = False, spatial_up: bool = True, - **kwargs, ): super().__init__() self.channels = channels @@ -760,9 +708,9 @@ class Upsample3D(nn.Module): hidden_states: torch.FloatTensor, memory_state=None, memory_cache=None, - **kwargs, ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels + if hidden_states.shape[1] != self.channels: + raise ValueError(f"SeedVR2 upsample expected {self.channels} channels, got {hidden_states.shape[1]}.") hidden_states = self.upscale_conv(hidden_states) b, channels, f, h, w = hidden_states.shape @@ -785,8 +733,6 @@ class Upsample3D(nn.Module): class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution.""" - def __init__( self, channels, @@ -794,7 +740,6 @@ class Downsample3D(nn.Module): inflation_mode = "tail", spatial_down: bool = False, temporal_down: bool = False, - **kwargs, ): super().__init__() self.channels = channels @@ -823,20 +768,17 @@ class Downsample3D(nn.Module): hidden_states: torch.FloatTensor, memory_state = None, memory_cache = None, - **kwargs, ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) + if hidden_states.shape[1] != self.channels: + raise ValueError(f"SeedVR2 downsample expected {self.channels} channels, got {hidden_states.shape[1]}.") if self.spatial_down: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - assert hidden_states.shape[1] == self.channels + if hidden_states.shape[1] != self.channels: + raise ValueError(f"SeedVR2 downsample expected {self.channels} channels after padding, got {hidden_states.shape[1]}.") hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) @@ -848,7 +790,6 @@ class ResnetBlock3D(nn.Module): self, in_channels: int, out_channels: Optional[int] = None, - dropout: float = 0.0, temb_channels: int = 512, groups: int = 32, groups_out: Optional[int] = None, @@ -857,7 +798,6 @@ class ResnetBlock3D(nn.Module): skip_time_act: bool = False, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", - **kwargs, ): super().__init__() self.in_channels = in_channels @@ -866,15 +806,14 @@ class ResnetBlock3D(nn.Module): self.skip_time_act = skip_time_act self.nonlinearity = nn.SiLU() if temb_channels is not None: - self.time_emb_proj = ops.Linear(temb_channels, out_channels) + self.time_emb_proj = ops.Linear(temb_channels, self.out_channels) else: self.time_emb_proj = None self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) if groups_out is None: groups_out = groups - self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.use_in_shortcut = self.in_channels != out_channels - self.dropout = torch.nn.Dropout(dropout) + self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=self.out_channels, eps=eps, affine=True) + self.use_in_shortcut = self.in_channels != self.out_channels self.conv1 = InflatedCausalConv3d( self.in_channels, self.out_channels, @@ -886,7 +825,7 @@ class ResnetBlock3D(nn.Module): self.conv2 = InflatedCausalConv3d( self.out_channels, - out_channels, + self.out_channels, kernel_size=3, stride=1, padding=1, @@ -897,7 +836,7 @@ class ResnetBlock3D(nn.Module): if self.use_in_shortcut: self.conv_shortcut = InflatedCausalConv3d( self.in_channels, - out_channels, + self.out_channels, kernel_size=1, stride=1, padding=0, @@ -905,9 +844,7 @@ class ResnetBlock3D(nn.Module): inflation_mode=inflation_mode, ) - def forward( - self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs - ): + def forward(self, input_tensor, temb, memory_state = None, memory_cache = None): hidden_states = input_tensor hidden_states = causal_norm_wrapper(self.norm1, hidden_states) @@ -928,7 +865,6 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache) if self.conv_shortcut is not None: @@ -944,7 +880,6 @@ class DownEncoderBlock3D(nn.Module): self, in_channels: int, out_channels: int, - dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_groups: int = 32, @@ -957,28 +892,23 @@ class DownEncoderBlock3D(nn.Module): ): super().__init__() resnets = [] - temporal_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) - temporal_modules.append(nn.Identity()) self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) if add_downsample: self.downsamplers = nn.ModuleList( @@ -1000,11 +930,9 @@ class DownEncoderBlock3D(nn.Module): hidden_states: torch.FloatTensor, memory_state = None, memory_cache = None, - **kwargs, ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): + for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) - hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -1018,7 +946,6 @@ class UpDecoderBlock3D(nn.Module): self, in_channels: int, out_channels: int, - dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_groups: int = 32, @@ -1032,33 +959,26 @@ class UpDecoderBlock3D(nn.Module): ): super().__init__() resnets = [] - temporal_modules = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( - # [Override] Replace module. ResnetBlock3D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) - temporal_modules.append(nn.Identity()) - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) if add_upsample: - # [Override] Replace module & use learnable upsample self.upsamplers = nn.ModuleList( [ Upsample3D( @@ -1080,9 +1000,8 @@ class UpDecoderBlock3D(nn.Module): memory_state=None, memory_cache=None, ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): + for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) - hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1096,7 +1015,6 @@ class UNetMidBlock3D(nn.Module): self, in_channels: int, temb_channels: int, - dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial @@ -1111,16 +1029,13 @@ class UNetMidBlock3D(nn.Module): resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention - # there is always at least one resnet resnets = [ - # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, @@ -1148,7 +1063,6 @@ class UNetMidBlock3D(nn.Module): ), residual_connection=True, bias=True, - upcast_softmax=True, ) ) else: @@ -1161,7 +1075,6 @@ class UNetMidBlock3D(nn.Module): temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, @@ -1172,7 +1085,7 @@ class UNetMidBlock3D(nn.Module): self.resnets = nn.ModuleList(resnets) def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None): - video_length, frame_height, frame_width = hidden_states.size()[-3:] + video_length = hidden_states.size(2) hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -1195,7 +1108,6 @@ class Encoder3D(nn.Module): layers_per_block: int = 2, norm_num_groups: int = 32, mid_block_add_attention=True, - # [Override] add temporal down num temporal_down_num: int = 2, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", @@ -1216,17 +1128,15 @@ class Encoder3D(nn.Module): self.mid_block = None self.down_blocks = nn.ModuleList([]) - # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - # [Override] to support temporal down block design is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last ones - assert down_block_type == "DownEncoderBlock3D" + if down_block_type != "DownEncoderBlock3D": + raise ValueError(f"SeedVR2 encoder only supports DownEncoderBlock3D, got {down_block_type}.") down_block = DownEncoderBlock3D( num_layers=self.layers_per_block, @@ -1242,7 +1152,6 @@ class Encoder3D(nn.Module): ) self.down_blocks.append(down_block) - # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, @@ -1256,7 +1165,6 @@ class Encoder3D(nn.Module): time_receptive_field=time_receptive_field, ) - # out self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) @@ -1274,17 +1182,13 @@ class Encoder3D(nn.Module): memory_state = None, memory_cache = None, ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" sample = sample.to(next(self.parameters()).device) sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) - # down for down_block in self.down_blocks: sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache) - # middle sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache) - # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) @@ -1303,7 +1207,6 @@ class Decoder3D(nn.Module): layers_per_block: int = 2, norm_num_groups: int = 32, mid_block_add_attention=True, - # [Override] add temporal up block inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up_num: int = 2, @@ -1326,7 +1229,6 @@ class Decoder3D(nn.Module): temb_channels = None - # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, @@ -1340,7 +1242,6 @@ class Decoder3D(nn.Module): time_receptive_field=time_receptive_field, ) - # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -1349,7 +1250,8 @@ class Decoder3D(nn.Module): is_final_block = i == len(block_out_channels) - 1 is_temporal_up_block = i < self.temporal_up_num - assert up_block_type == "UpDecoderBlock3D" + if up_block_type != "UpDecoderBlock3D": + raise ValueError(f"SeedVR2 decoder only supports UpDecoderBlock3D, got {up_block_type}.") up_block = UpDecoderBlock3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, @@ -1365,7 +1267,6 @@ class Decoder3D(nn.Module): self.up_blocks.append(up_block) prev_output_channel = output_channel - # out self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 ) @@ -1375,7 +1276,6 @@ class Decoder3D(nn.Module): ) - # Note: Just copy from Decoder. def forward( self, sample: torch.FloatTensor, @@ -1388,15 +1288,12 @@ class Decoder3D(nn.Module): sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) sample = sample.to(upscale_dtype) - # up for up_block in self.up_blocks: sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) - # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) @@ -1415,8 +1312,6 @@ class VideoAutoencoderKL(nn.Module): inflation_mode = "pad", time_receptive_field: _receptive_field_t = "full", slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, - *args, - **kwargs, ): self.slicing_sample_min_size = slicing_sample_min_size self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) @@ -1425,7 +1320,6 @@ class VideoAutoencoderKL(nn.Module): up_block_types = ("UpDecoderBlock3D",) * 4 super().__init__() - # pass init params to Encoder self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, @@ -1433,13 +1327,11 @@ class VideoAutoencoderKL(nn.Module): block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, - # [Override] add temporal_down_num parameter temporal_down_num=temporal_scale_num, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) - # pass init params to Decoder self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, @@ -1447,7 +1339,6 @@ class VideoAutoencoderKL(nn.Module): block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, - # [Override] add temporal_up_num parameter temporal_up_num=temporal_scale_num, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, @@ -1489,11 +1380,10 @@ class VideoAutoencoderKL(nn.Module): return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size =1 - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size: memory_cache = {} split_size = max( - self.slicing_sample_min_size * sp_size, + self.slicing_sample_min_size, getattr(self, "temporal_downsample_factor", 1), ) x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) @@ -1518,10 +1408,9 @@ class VideoAutoencoderKL(nn.Module): return self._encode(x) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = 1 - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size: memory_cache = {} - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size, dim=2) decoded_slices = [ self._decode( torch.cat((z[:, :, :1], z_slices[0]), dim=2), @@ -1538,33 +1427,28 @@ class VideoAutoencoderKL(nn.Module): else: return self._decode(z) - def forward( - self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs - ): - # x: [b c t h w] + def forward(self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all"): def _unwrap(value): return value[0] if isinstance(value, tuple) else value if mode == "encode": return _unwrap(self.encode(x)) - elif mode == "decode": + if mode == "decode": return _unwrap(self.decode_(x)) - else: + if mode == "all": latent = _unwrap(self.encode(x)) return _unwrap(self.decode_(latent)) + raise ValueError(f"Unknown SeedVR2 VAE forward mode: {mode}") class VideoAutoencoderKLWrapper(VideoAutoencoderKL): def __init__( self, - *args, spatial_downsample_factor = 8, temporal_downsample_factor = 4, - **kwargs, ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor - self.enable_tiling = False - super().__init__(*args, **kwargs) + super().__init__() self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) def forward(self, x: torch.FloatTensor): @@ -1581,7 +1465,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): z = p.squeeze(2) return z, p - def encode(self, x, orig_dims=None): + def encode(self, x): z, _ = self._encode_with_raw_latent(x) return z @@ -1594,26 +1478,27 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): ) if z.ndim == 5: - b, c, t_latent, h, w = z.shape - if c != 16: + _, c, _, _, _ = z.shape + if c != SEEDVR2_LATENT_CHANNELS: raise RuntimeError( "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " - f"have 16 channels; got shape {tuple(z.shape)}." + f"have {SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(z.shape)}." ) latent = z elif z.ndim == 4: b, tc, h, w = z.shape - if tc % 16 != 0: + if tc % SEEDVR2_LATENT_CHANNELS != 0: raise RuntimeError( "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " - "use collapsed channel layout (B, 16*T, H, W); " + f"use collapsed channel layout (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W); " f"got shape {tuple(z.shape)}." ) - latent = z.reshape(b, 16, -1, h, w) + latent = z.reshape(b, SEEDVR2_LATENT_CHANNELS, -1, h, w) else: raise RuntimeError( "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " - "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " + f"4-D collapsed (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W) or " + f"5-D (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " f"got shape {tuple(z.shape)}." ) scale = BYTEDANCE_VAE_SCALING_FACTOR @@ -1621,10 +1506,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): latent = latent / scale + shift self.device = latent.device - self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) + enable_tiling = seedvr2_tiling.get("enable_tiling", False) - if self.enable_tiling: + if enable_tiling: decode_seedvr2_args = dict(seedvr2_tiling) + decode_seedvr2_args.pop("enable_tiling", None) tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) decode_seedvr2_args["tile_overlap"] = ( @@ -1641,7 +1527,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): else: x = super().decode_(latent) - # ensure even dims for save video h, w = x.shape[-2:] w2 = w - (w % 2) h2 = h - (h % 2) @@ -1693,7 +1578,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): if samples.ndim == 4: samples = samples.unsqueeze(2) samples = samples.contiguous() - samples = samples * 0.9152 + samples = samples * BYTEDANCE_VAE_SCALING_FACTOR return samples def comfy_memory_used_decode(self, shape): @@ -1707,15 +1592,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): # plus int64 sort indices dominate peak memory, not the VAE weight dtype. if len(shape) == 5: candidates = [] - if shape[1] == 16: + if shape[1] == SEEDVR2_LATENT_CHANNELS: candidates.append((shape[2], shape[3], shape[4])) - if shape[-1] == 16: + if shape[-1] == SEEDVR2_LATENT_CHANNELS: candidates.append((shape[1], shape[2], shape[3])) if len(candidates) == 0: candidates.append((shape[2], shape[3], shape[4])) pixels = max(output_pixels(*candidate) for candidate in candidates) elif len(shape) == 4: - latent_t = max(1, (shape[1] + 15) // 16) + latent_t = max(1, (shape[1] + SEEDVR2_LATENT_CHANNELS - 1) // SEEDVR2_LATENT_CHANNELS) pixels = output_pixels(latent_t, shape[2], shape[3]) else: pixels = output_pixels(1, shape[-2], shape[-1]) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0004b339a..4b02e5bb4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -933,7 +933,8 @@ class HunyuanDiT(BaseModel): class SeedVR2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT) + def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) condition = kwargs.get("condition", None) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bf44b832c..97673988b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,43 +598,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config - if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b + seedvr2_7b_separate_key = "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) + if seedvr2_7b_separate_key in state_dict_keys and state_dict[seedvr2_7b_separate_key].shape[1] == 3072: # seedvr2 7b dit_config = {} dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 3072 dit_config["heads"] = 24 dit_config["num_layers"] = 36 - # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` - # submodules) at EVERY block — verified by inspecting the 7B - # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means - # ``MMModule.shared_weights=False``). Native NaDiT computes - # per-block ``shared_weights = not (i < mm_layers)``, so to keep - # every block non-shared we set ``mm_layers = num_layers``. - # Without this, blocks at index >= mm_layers (default 10) try to - # load ``blocks.N.*.all.*`` keys that don't exist in the file, - # silently miss-load → all-black output. + # This checkpoint uses separate vid/txt MMModule keys in every block. dit_config["mm_layers"] = 36 dit_config["norm_eps"] = 1e-5 dit_config["rope_type"] = "rope3d" dit_config["rope_dim"] = 64 dit_config["mlp_type"] = "normal" return dit_config - elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b + if "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b dit_config = {} dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 3072 dit_config["heads"] = 24 dit_config["num_layers"] = 36 - # This checkpoint layout carries shared ``all.`` MMModule keys. - # Preserve the historical split: the initial blocks use separate - # vid/txt modules, later blocks use shared modules. + # This checkpoint uses shared all.* MMModule keys after the initial blocks. dit_config["mm_layers"] = 10 dit_config["norm_eps"] = 1e-5 dit_config["rope_type"] = "rope3d" dit_config["rope_dim"] = 64 dit_config["mlp_type"] = "swiglu" return dit_config - elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b + if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b dit_config = {} dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 2560 @@ -1150,8 +1141,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["heatmap_head"] = True return unet_config +def normalize_seedvr2_unet_config(unet_config): + if unet_config.get("image_model") != "seedvr2" or "num_heads" not in unet_config: + return unet_config + + unet_config = dict(unet_config) + num_heads = unet_config.pop("num_heads") + if "heads" in unet_config and unet_config["heads"] != num_heads: + raise ValueError( + f"SeedVR2 config has conflicting heads={unet_config['heads']} and num_heads={num_heads}." + ) + unet_config["heads"] = num_heads + return unet_config + def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""): + unet_config = normalize_seedvr2_unet_config(unet_config) for model_config in comfy.supported_models.models: if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix): return model_config(unet_config) diff --git a/comfy/sd.py b/comfy/sd.py index 06c6196d3..20726d782 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -472,8 +472,7 @@ class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - if metadata is None or metadata.get("keep_diffusers_format") != "true": - sd = diffusers_convert.convert_vae_state_dict(sd) + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -549,7 +548,7 @@ class VAE: self.latent_channels = 16 elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.latent_channels = 16 + self.latent_channels = comfy.ldm.seedvr.vae.SEEDVR2_LATENT_CHANNELS self.latent_dim = 3 self.disable_offload = True self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape) @@ -1074,6 +1073,20 @@ class VAE: out = self.first_stage_model.encode_tiled(x, **kwargs) return out.to(device=self.output_device, dtype=self.vae_output_dtype()) + def _owned_tiled_args(self, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + args = {} + if tile_x is not None: + args["tile_x"] = tile_x + if tile_y is not None: + args["tile_y"] = tile_y + if overlap is not None: + args["overlap"] = overlap + if tile_t is not None: + args["tile_t"] = tile_t + if overlap_t is not None: + args["overlap_t"] = overlap_t + return args + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1153,18 +1166,7 @@ class VAE: with model_management.cuda_device_context(self.device): if self.handles_tiling and dims in (2, 3): - tiled_args = {} - if tile_x is not None: - tiled_args["tile_x"] = tile_x - if tile_y is not None: - tiled_args["tile_y"] = tile_y - if overlap is not None: - tiled_args["overlap"] = overlap - if tile_t is not None: - tiled_args["tile_t"] = tile_t - if overlap_t is not None: - tiled_args["overlap_t"] = overlap_t - output = self._decode_tiled_owned(samples, **tiled_args) + output = self._decode_tiled_owned(samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t)) elif dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) @@ -1269,18 +1271,7 @@ class VAE: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: if self.handles_tiling: - tiled_args = {} - if tile_x is not None: - tiled_args["tile_x"] = tile_x - if tile_y is not None: - tiled_args["tile_y"] = tile_y - if overlap is not None: - tiled_args["overlap"] = overlap - if tile_t is not None: - tiled_args["tile_t"] = tile_t - if overlap_t is not None: - tiled_args["overlap_t"] = overlap_t - samples = self._encode_tiled_owned(pixel_samples, **tiled_args) + samples = self._encode_tiled_owned(pixel_samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t)) else: if tile_t is not None: tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) @@ -1850,7 +1841,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5c849358e..65cc79bb2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1688,6 +1688,7 @@ class SeedVR2(supported_models_base.BASE): unet_config = { "image_model": "seedvr2" } + unet_extra_config = {} required_keys = { "{}positive_conditioning", "{}negative_conditioning", diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index bf5b3c15c..933b76237 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -19,21 +19,14 @@ from comfy.ldm.seedvr.constants import ( ) from torchvision.transforms import functional as TVF -from torchvision.transforms import Lambda from torchvision.transforms.functional import InterpolationMode -_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( - "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" -) - -# Private sentinel for getattr default: distinguishes "attribute missing" -# from "attribute present but None" so the failure message is accurate. +_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" _ATTR_MISSING = object() def _resolve_seedvr2_diffusion_model(model): - """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" inner = getattr(model, "model", _ATTR_MISSING) if inner is _ATTR_MISSING: raise RuntimeError( @@ -59,15 +52,7 @@ def _resolve_seedvr2_diffusion_model(model): return diffusion_model -def get_conditions(latent, latent_blur): - t, h, w, c = latent.shape - cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - def div_pad(image, factor): - height_factor, width_factor = factor height, width = image.shape[-2:] @@ -77,31 +62,25 @@ def div_pad(image, factor): if pad_height == 0 and pad_width == 0: return image - if isinstance(image, torch.Tensor): - padding = (0, pad_width, 0, pad_height) - image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) - - return image + padding = (0, pad_width, 0, pad_height) + return torch.nn.functional.pad(image, padding, mode='constant', value=0.0) def cut_videos(videos): t = videos.size(1) + if t < 1: + raise ValueError("SeedVR2Preprocess expected at least one frame.") if t == 1: return videos - if t <= 4 : - padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 - ((t - 1) % (4)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4) == 0 + if t <= 4: + padding = videos[:, -1:].repeat(1, 4 - t + 1, 1, 1, 1) + return torch.cat([videos, padding], dim=1) + if (t - 1) % 4 == 0: return videos + padding = videos[:, -1:].repeat(1, 4 - ((t - 1) % 4), 1, 1, 1) + videos = torch.cat([videos, padding], dim=1) + if (videos.size(1) - 1) % 4 != 0: + raise ValueError(f"SeedVR2Preprocess failed to pad video length to 4n+1; got {videos.size(1)} frames.") + return videos def _seedvr2_input_shorter_edge(images, node_name): if images.dim() == 4: @@ -136,8 +115,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name): b, t, c, h, w = images.shape images = images.reshape(b * t, c, h, w) - clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - images = clip(images) + images = torch.clamp(images, 0.0, 1.0) images = div_pad(images, (16, 16)) _, _, new_h, new_w = images.shape @@ -295,7 +273,6 @@ class SeedVR2PostProcessing(io.ComfyNode): def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) while True: - next_chunk_size = None try: return cls._run_color_transfer_chunks( decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, @@ -307,9 +284,7 @@ class SeedVR2PostProcessing(io.ComfyNode): "SeedVR2PostProcessing: color correction OOM at one frame; " f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." ) from e - next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - - chunk_size = next_chunk_size + chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) @classmethod def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): @@ -392,10 +367,8 @@ class SeedVR2Conditioning(io.ComfyNode): io.Latent.Input("vae_conditioning", display_name="latent"), ], outputs=[ - io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."), io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."), io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."), - io.Latent.Output(display_name="latent", tooltip="The latent to denoise."), ], ) @@ -408,29 +381,30 @@ class SeedVR2Conditioning(io.ComfyNode): "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." ) - if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS: + raise ValueError( + "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " + f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " + f"got channel-last shape {tuple(vae_conditioning.shape)}." + ) raise ValueError( - "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " - f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " - f"got channel-last shape {tuple(vae_conditioning.shape)}." + "SeedVR2Conditioning expects SeedVR2 VAE latents with " + f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}." ) vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() - model_patcher = model - model = _resolve_seedvr2_diffusion_model(model_patcher) + model = _resolve_seedvr2_diffusion_model(model) pos_cond = model.positive_conditioning neg_cond = model.negative_conditioning - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) + mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,)) + condition = torch.cat((vae_conditioning, mask), dim=-1) condition = condition.movedim(-1, 1) - latent = vae_conditioning.movedim(-1, 1) - - latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4]) - condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4]) negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] - return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) + return io.NodeOutput(positive, negative) class SeedVRExtension(ComfyExtension): @override diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py index d36e50428..045502b5b 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -1,20 +1,15 @@ -"""Consolidated SeedVR2 conditioning and refactor regression tests. - -Merges the prior test_seedvr2_refactor_nodes.py and -test_seedvr_conditioning_hardening.py modules. Refactor tests use the -top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests -use _import_nodes_seedvr_isolated() for sys.modules isolation when -mocking comfy.model_management. -""" +"""SeedVR2 conditioning node regression tests.""" import importlib import sys from unittest.mock import MagicMock +import pytest import torch import torch.nn as nn from comfy.cli_args import args as cli_args +from comfy.ldm.seedvr.constants import SEEDVR2_LATENT_CHANNELS if not torch.cuda.is_available(): cli_args.cpu = True @@ -79,21 +74,18 @@ def _import_nodes_seedvr_isolated(): class _Rope(nn.Module): - """Minimal RoPE stub exposing a `freqs` parameter.""" def __init__(self): super().__init__() self.freqs = nn.Parameter(torch.zeros(4)) class _Block(nn.Module): - """Minimal transformer block stub holding a `_Rope`.""" def __init__(self): super().__init__() self.rope = _Rope() class _DiffusionModel(nn.Module): - """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" def __init__(self, n_blocks=3, conditioning_dtype=torch.float32): super().__init__() self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) @@ -102,18 +94,16 @@ class _DiffusionModel(nn.Module): class _ModelInner: - """Inner model wrapper exposing `.diffusion_model`.""" def __init__(self, diffusion_model): self.diffusion_model = diffusion_model class _ModelPatcher: - """ModelPatcher stub exposing `.model._ModelInner`.""" def __init__(self, diffusion_model): self.model = _ModelInner(diffusion_model) -def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): +def test_seedvr2_conditioning_schema_exposes_conditioning_outputs(): nodes_seedvr, restore = _import_nodes_seedvr_isolated() try: schema = nodes_seedvr.SeedVR2Conditioning.define_schema() @@ -123,37 +113,50 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): ] assert schema.inputs[1].display_name == "latent" assert [output.display_name for output in schema.outputs] == [ - "model", "positive", "negative", - "latent", ] finally: restore() -def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): +def test_seedvr2_conditioning_rejects_wrong_latent_channels(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + patcher = _ModelPatcher(_DiffusionModel()) + vae_conditioning = {"samples": torch.zeros(1, 8, 2, 2, 2)} + + with pytest.raises(ValueError, match=f"{SEEDVR2_LATENT_CHANNELS} channels"): + nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning) + finally: + restore() + + +def test_seedvr2_conditioning_returns_conditioning_deterministically(): nodes_seedvr, restore = _import_nodes_seedvr_isolated() try: diffusion_model = _DiffusionModel() patcher = _ModelPatcher(diffusion_model) - samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) + samples = torch.arange( + 1, + 1 + SEEDVR2_LATENT_CHANNELS * 3 * 2 * 2, + dtype=torch.float32, + ).reshape(1, SEEDVR2_LATENT_CHANNELS, 3, 2, 2) vae_conditioning = {"samples": samples} - _, first_positive, first_negative, first_latent = ( + first_positive, first_negative = ( nodes_seedvr.SeedVR2Conditioning.execute( patcher, vae_conditioning, ) ) - _, second_positive, second_negative, second_latent = ( + second_positive, second_negative = ( nodes_seedvr.SeedVR2Conditioning.execute( patcher, vae_conditioning, ) ) - expected_latent = samples.reshape(1, 6, 2, 2) channel_last = samples.movedim(1, -1).contiguous() expected_condition = torch.cat( [ @@ -161,10 +164,8 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): torch.ones((*channel_last.shape[:-1], 1)), ], dim=-1, - ).movedim(-1, 1).reshape(1, 9, 2, 2) + ).movedim(-1, 1) - assert torch.equal(first_latent["samples"], expected_latent) - assert torch.equal(second_latent["samples"], expected_latent) assert torch.equal( first_positive[0][1]["condition"], expected_condition, diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 587c393c9..192fdbfe5 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -201,6 +201,17 @@ class TestModelDetection: del sd["positive_conditioning"] assert model_config_from_unet_config(unet_config, sd) is None + def test_seedvr2_model_match_normalizes_num_heads(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + unet_config["num_heads"] = unet_config.pop("heads") + + model_config = model_config_from_unet_config(unet_config, sd) + + assert type(model_config).__name__ == "SeedVR2" + assert model_config.unet_config["heads"] == 24 + assert "num_heads" not in model_config.unet_config + def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self): sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd()) diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py index d4af4c2b1..7ea7a143e 100644 --- a/tests-unit/comfy_test/seedvr_vae_forward_test.py +++ b/tests-unit/comfy_test/seedvr_vae_forward_test.py @@ -1,22 +1,6 @@ -"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must -honor the actual tensor/tuple return contract of ``encode()`` and -``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` -or ``.sample`` attributes on those returns. - -The pre-fix body raised ``AttributeError: 'Tensor' object has no -attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and -``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` -for ``mode == "decode"`` (the class only defines ``decode_`` with a -trailing underscore). The post-fix body unwraps the optional one-element -tuple shape that ``return_dict=False`` produces and returns the tensor -directly. - -Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses -the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and -overrides ``encode``/``decode_`` with known tensors so the contract can -be probed without loading any real VAE weights. -""" +"""Regression tests for the SeedVR2 VAE forward return contract.""" +import pytest import torch import torch.nn as nn @@ -25,13 +9,13 @@ from comfy.cli_args import args as cli_args if not torch.cuda.is_available(): cli_args.cpu = True -from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 +from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402 -_LATENT_SHAPE = (1, 16, 2, 2, 2) +_LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2) _DECODED_SHAPE = (1, 3, 5, 16, 16) _INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) -_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) +_INPUT_DECODE_SHAPE = _LATENT_SHAPE class _StubVAE(VideoAutoencoderKL): @@ -64,8 +48,6 @@ def test_forward_decode_returns_tensor(): class _TupleReturningStubVAE(VideoAutoencoderKL): - """Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``.""" - def __init__(self): nn.Module.__init__(self) self._encode_tensor = torch.zeros(*_LATENT_SHAPE) @@ -84,3 +66,9 @@ def test_forward_all_unwraps_one_tuple_at_each_step(): result = vae.forward(x, mode="all") assert type(result) is torch.Tensor assert result.shape == torch.Size(_DECODED_SHAPE) + + +def test_forward_rejects_unknown_mode(): + vae = _StubVAE() + with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"): + vae.forward(torch.zeros(*_INPUT_ENCODE_SHAPE), mode="bogus") diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py index f03c0406c..8e08b6dde 100644 --- a/tests-unit/comfy_test/test_seedvr2_dtype.py +++ b/tests-unit/comfy_test/test_seedvr2_dtype.py @@ -41,8 +41,9 @@ def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper) - estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160)) - old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 + latent_channels = seedvr_vae.SEEDVR2_LATENT_CHANNELS + estimate = wrapper.comfy_memory_used_decode((1, latent_channels, 26, 120, 160)) + old_estimate = latent_channels * 120 * 160 * (4 * 8 * 8) * 2 assert estimate == 101 * 960 * 1280 * 160 assert estimate > 15 * 1024 ** 3 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py index 966e9465d..fe4bde1c4 100644 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -1,16 +1,4 @@ -"""Consolidated SeedVR2 internals regression tests. - -Sources (all merged verbatim, helper names disambiguated where colliding): - - * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare - memory_occupy against get_norm_limit(), not float('inf'). - * SeedVR2 variable-length attention split-loop contract. - -Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and -comfy.ldm.modules.attention transitively pull in comfy.model_management, -which probes torch.cuda.current_device() at import time unless args.cpu is -set first. -""" +"""SeedVR2 internals regression tests.""" from __future__ import annotations @@ -35,10 +23,6 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402 from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402 -# --------------------------------------------------------------------------- -# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) -# --------------------------------------------------------------------------- - _NUM_CHANNELS = 8 _NUM_GROUPS = 4 _TENSOR_SHAPE = (1, 8, 2, 4, 4) @@ -89,10 +73,6 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): set_norm_limit(None) -# --------------------------------------------------------------------------- -# SeedVR2 var_attention split-loop tests -# --------------------------------------------------------------------------- - def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): dim = 8 heads = 2 @@ -140,18 +120,8 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa assert call["heads"] == heads assert call["skip_reshape"] is True assert call["skip_output_reshape"] is True - torch.testing.assert_close( - call["cu_seqlens_q"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - torch.testing.assert_close( - call["cu_seqlens_k"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) + assert call["cu_seqlens_q"] == [0, 7, 14] + assert call["cu_seqlens_k"] == [0, 7, 14] def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): @@ -160,7 +130,7 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) k = q + 100 v = q + 200 - cu = torch.tensor([0, 2, 5], dtype=torch.int32) + cu = [0, 2, 5] calls = [] def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): @@ -197,20 +167,3 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) torch.testing.assert_close(out, q + v, rtol=0, atol=0) - -def test_var_attention_optimized_split_rejects_bad_offsets(): - q = torch.randn(5, 2, 3) - cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) - cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) - - with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): - var_attention_optimized_split( - q, - q, - q, - 2, - cu_bad, - cu_ok, - skip_reshape=True, - skip_output_reshape=True, - ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py index 06b2f1564..421e98e43 100644 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -1,17 +1,10 @@ -"""Consolidated SeedVR2 model/graph/forward regression tests. - -Merged from: -- seedvr_model_test.py -- test_seedvr_7b_final_block_text_path.py -- test_seedvr_forward_no_device_cast.py -- test_seedvr_latent_format.py -- test_seedvr2_vae_graph_boundaries.py -""" +"""SeedVR2 model, latent-format, and VAE graph regression tests.""" from __future__ import annotations from unittest.mock import MagicMock +import pytest import torch from torch import nn @@ -22,7 +15,6 @@ if not torch.cuda.is_available(): import comfy # noqa: E402 import comfy.latent_formats # noqa: E402 -import comfy.ldm.seedvr.model # noqa: E402 import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 import comfy.model_management # noqa: E402 @@ -33,9 +25,7 @@ import nodes as nodes_mod # noqa: E402 from comfy.ldm.seedvr.model import NaDiT # noqa: E402 -# --------------------------------------------------------------------------- -# Helpers from seedvr_model_test.py -# --------------------------------------------------------------------------- +_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS def _make_standin(positive_conditioning): @@ -51,11 +41,6 @@ def _make_standin(positive_conditioning): return _StandIn() -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - class _StubModule(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -88,11 +73,6 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis return flags -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - class _Model: def __init__(self, latent_format): self._latent_format = latent_format @@ -102,11 +82,6 @@ class _Model: return self._latent_format -# --------------------------------------------------------------------------- -# Helpers from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - class _Patcher: def get_free_memory(self, device): return 1024 * 1024 * 1024 @@ -136,14 +111,14 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) if z.ndim == 4: b, tc, h, w = z.shape - t = tc // 16 + t = tc // _LATENT_CHANNELS else: b, _, t, h, w = z.shape return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): - raw_latent = torch.full((1, 16, 1, 4, 5), 2.0) + raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0) seen_shapes = [] def base_encode(self, x): @@ -159,12 +134,12 @@ def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): latent = vae.encode(torch.zeros(1, 3, 32, 40)) assert type(latent) is torch.Tensor - assert tuple(latent.shape) == (1, 16, 4, 5) + assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5) assert seen_shapes == [(1, 3, 1, 32, 40)] def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): - raw_latent = torch.full((1, 16, 1, 4, 5), 3.0) + raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0) def base_encode(self, x): return raw_latent.to(device=x.device, dtype=x.dtype) @@ -177,8 +152,8 @@ def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40)) - assert tuple(latent.shape) == (1, 16, 4, 5) - assert tuple(raw.shape) == (1, 16, 1, 4, 5) + assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5) + assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5) assert torch.equal(raw, raw_latent) @@ -188,7 +163,7 @@ def _make_vae(wrapper): vae.device = torch.device("cpu") vae.output_device = torch.device("cpu") vae.vae_dtype = torch.float32 - vae.latent_channels = 16 + vae.latent_channels = _LATENT_CHANNELS vae.latent_dim = 3 vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) @@ -212,13 +187,7 @@ def _make_vae(wrapper): return vae -# --------------------------------------------------------------------------- -# Tests from seedvr_model_test.py -# --------------------------------------------------------------------------- - - def test_missing_context_falls_back_to_positive_buffer(): - """``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion.""" pos_buffer = torch.full((58, 5120), 7.0) standin = _make_standin(pos_buffer) txt, txt_shape = standin._resolve_text_conditioning(None) @@ -231,11 +200,6 @@ def test_missing_context_falls_back_to_positive_buffer(): assert txt_shape[0, 0].item() == 58 -# --------------------------------------------------------------------------- -# Tests from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ False, @@ -268,43 +232,49 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) -# --------------------------------------------------------------------------- -# Tests from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- +def test_seedvr2_forward_requires_conditioning_latents(): + model = NaDiT.__new__(NaDiT) + x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5) + + with pytest.raises(ValueError, match="requires conditioning latents"): + NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None) -def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): +def test_seedvr2_latent_format_uses_native_video_latent_shape(): latent_format = comfy.latent_formats.SeedVR2() latent_image = torch.zeros(1, 1, 4, 5) fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) - assert latent_format.latent_channels == 16 - assert latent_format.latent_dimensions == 2 - assert fixed.shape == (1, 16, 4, 5) + assert latent_format.latent_channels == _LATENT_CHANNELS + assert latent_format.latent_dimensions == 3 + assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5) -# --------------------------------------------------------------------------- -# Tests from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- +def test_seedvr2_model_requires_native_5d_latent(): + latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5) + assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent + + with pytest.raises(ValueError, match="5-D native latent"): + NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent") def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - encoded = torch.full((1, 16, 2, 4, 5), 2.0) + encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0) vae = _make_vae(_EncodeWrapper(encoded)) pixels = torch.zeros(1, 5, 32, 40, 3) node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] node_latent = node_output["samples"] assert set(node_output) == {"samples"} - assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) + assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5) assert node_latent.dtype == torch.float32 assert node_latent.stride()[-1] == 1 - assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) + assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR)) - tiled = torch.full((1, 16, 2, 4, 5), 3.0) + tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0) monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) tiled_output = nodes_mod.VAEEncodeTiled().encode( vae, @@ -316,9 +286,9 @@ def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeyp )[0] tiled_latent = tiled_output["samples"] assert set(tiled_output) == {"samples"} - assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) + assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5) assert tiled_latent.dtype == torch.float32 - assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) + assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR)) def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): @@ -327,7 +297,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): nodes_mod.VAEDecodeTiled().decode( vae, - {"samples": torch.zeros(1, 16, 2, 4, 5)}, + {"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)}, tile_size=512, overlap=64, temporal_size=16, @@ -339,7 +309,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): # knobs are no-ops at the wrapper. assert vae.first_stage_model.calls == [ { - "shape": (1, 16, 2, 4, 5), + "shape": (1, _LATENT_CHANNELS, 2, 4, 5), "seedvr2_tiling": { "enable_tiling": True, "tile_size": (512, 512), diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py index ea9f978f3..c486b9195 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_decode.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_decode.py @@ -13,6 +13,9 @@ import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 from comfy_extras import nodes_seedvr # noqa: E402 +_LATENT_CHANNELS = vae_mod.SEEDVR2_LATENT_CHANNELS + + def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( vae_mod.VideoAutoencoderKLWrapper @@ -40,7 +43,7 @@ def _decode_with_patches(wrapper, z): def test_decode_b2_t3_multi_frame_batch_unchanged(): wrapper = _make_wrapper() - out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) + out = _decode_with_patches(wrapper, torch.zeros(2, _LATENT_CHANNELS * 3, 2, 2)) assert tuple(out.shape) == (2, 3, 3, 16, 16) @@ -62,17 +65,17 @@ def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preproc wrapper = _Wrapper() with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): - out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) + out = wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)) assert tuple(out.shape) == (1, 3, 2, 32, 40) - assert wrapper.calls == [(1, 16, 2, 4, 5)] + assert wrapper.calls == [(1, _LATENT_CHANNELS, 2, 4, 5)] def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): wrapper = _Wrapper() with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): - wrapper.decode(torch.zeros(1, 16, 4)) + wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 4)) def _t_padded(t_in: int) -> int: diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py index 0d3c97e4a..33c6d8915 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -16,9 +16,7 @@ import comfy.sd as sd_mod # noqa: E402 from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_decode_latent_min_size_override.py -# --------------------------------------------------------------------------- +_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): @@ -44,7 +42,7 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) vae = StubVAEModel() - z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) + z = torch.zeros((1, _LATENT_CHANNELS, 5, 8, 8), dtype=torch.float32) tiled_vae( z, @@ -61,11 +59,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): assert vae.slicing_latent_min_size == 2 -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_encode_runt_slice_override.py -# --------------------------------------------------------------------------- - - def test_zero_temporal_size_preserves_min_size_when_encode_raises(): class RaisingVAEModel(torch.nn.Module): def __init__(self): @@ -110,7 +103,7 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing(): def encode(self, t_chunk): self.calls.append(tuple(t_chunk.shape)) b, _, _, h, w = t_chunk.shape - return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype) + return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=t_chunk.dtype) vae = TensorEncodeVAEModel() x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32) @@ -126,12 +119,34 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing(): ) assert vae.calls == [(2, 3, 1, 64, 64)] - assert tuple(out.shape) == (2, 16, 1, 8, 8) + assert tuple(out.shape) == (2, _LATENT_CHANNELS, 1, 8, 8) -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_temporal_slicing.py -# --------------------------------------------------------------------------- +def test_tiled_vae_preserves_input_dtype_on_single_tile(): + class FloatOutputVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def encode(self, t_chunk): + b, _, _, h, w = t_chunk.shape + return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=torch.float32) + + out = tiled_vae( + torch.zeros((1, 3, 1, 64, 64), dtype=torch.float16), + FloatOutputVAEModel(), + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + + assert out.dtype == torch.float16 class _SlicingDecodeVAE(nn.Module): @@ -164,7 +179,10 @@ class _SlicingDecodeVAE(nn.Module): def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): vae = _SlicingDecodeVAE(slicing_latent_min_size=2) - z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) + z = torch.arange( + _LATENT_CHANNELS * 5 * 8 * 8, + dtype=torch.float32, + ).reshape(1, _LATENT_CHANNELS, 5, 8, 8) tiled_vae( z, @@ -199,16 +217,11 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): return torch.zeros(1, 3, 1, 16, 16) with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): - wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) + wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 2), seedvr2_tiling=seedvr2_tiling) assert captured["temporal_overlap"] == 7 -# --------------------------------------------------------------------------- -# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py -# --------------------------------------------------------------------------- - - def _force_oom(*a, **k): raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") @@ -256,10 +269,10 @@ def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): def test_4d_seedvr2_latent_routes_to_owned_decode_tiled(): wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( seedvr_vae_mod.VideoAutoencoderKLWrapper) - vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) + vae = _make_vae(wrapper, latent_channels=_LATENT_CHANNELS, latent_dim=3) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) + _dispatch(vae, torch.zeros(1, _LATENT_CHANNELS * 3, 8, 8), seedvr2_call, generic_call, True) assert seedvr2_call.call_count == 1 assert generic_call.call_count == 0 @@ -275,11 +288,6 @@ def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): assert seedvr2_call.call_count == 0 -# --------------------------------------------------------------------------- -# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py -# --------------------------------------------------------------------------- - - def _populate_common_vae_attrs_fallback(vae): vae.patcher = MagicMock() vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) @@ -291,7 +299,7 @@ def _populate_common_vae_attrs_fallback(vae): vae.upscale_ratio = 8 vae.upscale_index_formula = None vae.output_channels = 3 - vae.latent_channels = 16 + vae.latent_channels = _LATENT_CHANNELS vae.latent_dim = 3 vae.downscale_ratio = 8 vae.downscale_index_formula = None @@ -334,8 +342,8 @@ def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom(): vae = _make_seedvr2_vae_fallback() pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + seedvr2_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8)) with patch.object(sd_mod.model_management, "raise_non_oom", lambda e: None), \ @@ -363,7 +371,7 @@ def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): vae = _make_non_seedvr2_vae_fallback() vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) vae.upscale_ratio = (lambda a: a * 4, 8, 8) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8)) pixel_samples = torch.zeros((1, 8, 64, 64, 3)) with patch.object(sd_mod.model_management, "load_models_gpu", From 6d72960989dc976a11118f3067712dbfe303c87a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 2 Jul 2026 23:04:52 -0400 Subject: [PATCH 12/12] Fix ruff. --- comfy/ldm/seedvr/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 872140558..772ec91c5 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -674,7 +674,6 @@ class NaSwinAttention(NaMMAttention): else: vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) txt_len_win_list = cache_win( "txt_len_list", lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)],