Add SeedVR2 model support

This commit is contained in:
John Pollock 2026-06-11 10:39:39 -05:00
parent 6d18f4adac
commit cd18c4460a
9 changed files with 1732 additions and 3 deletions

View File

@ -779,6 +779,9 @@ class ACEAudio(LatentFormat):
latent_channels = 8 latent_channels = 8
latent_dimensions = 2 latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
class ACEAudio15(LatentFormat): class ACEAudio15(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1

View File

@ -22,7 +22,7 @@ def torch_cat_if_needed(xl, dim):
else: else:
return None return None
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq. From Fairseq.
@ -33,11 +33,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1 assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2 half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device) emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :] emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0)) emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb return emb

View File

@ -0,0 +1,77 @@
import torch
from comfy.ldm.modules import attention as _attention
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_dtype = q_i.dtype
if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
q_i = q_i.to(torch.bfloat16)
k_i = k_i.to(torch.bfloat16)
v_i = v_i.to(torch.bfloat16)
out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
if out_i.dtype != out_dtype:
out_i = out_i.to(out_dtype)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
optimized_var_attention = var_attention_optimized_split

View File

@ -0,0 +1,72 @@
"""Named constants for the SeedVR2 integration, grouped by provenance.
Provenance prefixes:
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
"""
# --------------------------------------------------------------------------------------
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
# --------------------------------------------------------------------------------------
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
# --------------------------------------------------------------------------------------
# B. Fork heuristics (SEEDVR2 - this integration)
# --------------------------------------------------------------------------------------
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
# --------------------------------------------------------------------------------------
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# --------------------------------------------------------------------------------------
# D. Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

1487
comfy/ldm/seedvr/model.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -54,6 +54,7 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model import comfy.ldm.ideogram4.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
@ -929,6 +930,16 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel): class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)

View File

@ -598,6 +598,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {} dit_config = {}
dit_config["image_model"] = "wan2.1" dit_config["image_model"] = "wan2.1"

View File

@ -1683,6 +1683,35 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
if (
dtype == torch.float16
and manual_cast_dtype is None
and comfy.model_management.should_use_bf16(device)
):
manual_cast_dtype = torch.bfloat16
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma): class ChromaRadiance(Chroma):
unet_config = { unet_config = {
"image_model": "chroma_radiance", "image_model": "chroma_radiance",
@ -2296,6 +2325,7 @@ models = [
HiDream, HiDream,
HiDreamO1, HiDreamO1,
Chroma, Chroma,
SeedVR2,
ChromaRadiance, ChromaRadiance,
ACEStep, ACEStep,
ACEStep15, ACEStep15,

View File

@ -115,7 +115,7 @@ class BASE:
replace_prefix = {"": self.vae_key_prefix[0]} replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype): def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
self.unet_config['dtype'] = dtype self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype self.manual_cast_dtype = manual_cast_dtype