Add SeedVR2 VAE support

This commit is contained in:
John Pollock 2026-06-11 10:39:54 -05:00
parent cd18c4460a
commit a7ea0c2773
2 changed files with 1912 additions and 20 deletions

1807
comfy/ldm/seedvr/vae.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,4 @@
import inspect
import json
import torch
from enum import Enum
@ -16,6 +17,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
@ -467,8 +469,10 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true":
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
@ -540,6 +544,20 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image * 2.0 - 1.0
self.crop_input = False
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -1006,6 +1024,10 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def _decode_tiled_owned(self, samples, **kwargs):
    """Decode latents through the model's own ``decode_tiled`` implementation.

    The latents are cast to ``self.vae_dtype`` and moved to ``self.device``,
    then handed to ``self.first_stage_model.decode_tiled`` together with any
    tiling keyword arguments. The result is copied onto ``self.output_device``
    in the VAE output dtype and passed through ``self.process_output``.

    Args:
        samples: latent tensor to decode.
        **kwargs: forwarded unchanged to ``first_stage_model.decode_tiled``.

    Returns:
        Whatever ``self.process_output`` returns for the decoded tensor.
    """
    latents = samples.to(self.vae_dtype).to(self.device)
    decoded = self.first_stage_model.decode_tiled(latents, **kwargs)
    # copy=True always yields a fresh tensor, even when the device and dtype
    # already match what the model returned.
    moved = decoded.to(
        device=self.output_device,
        dtype=self.vae_output_dtype(),
        copy=True,
    )
    return self.process_output(moved)
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -1042,6 +1064,11 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def _encode_tiled_owned(self, pixel_samples, **kwargs):
    """Encode pixels through the model's own ``encode_tiled`` implementation.

    The pixels go through ``self.process_input``, are cast to
    ``self.vae_dtype`` and moved to ``self.device``, and are then handed to
    ``self.first_stage_model.encode_tiled`` together with any tiling keyword
    arguments.

    Args:
        pixel_samples: pixel tensor to encode.
        **kwargs: forwarded unchanged to ``first_stage_model.encode_tiled``.

    Returns:
        The encoded tensor on ``self.output_device`` in the VAE output dtype.
    """
    prepared = self.process_input(pixel_samples)
    prepared = prepared.to(self.vae_dtype).to(self.device)
    encoded = self.first_stage_model.encode_tiled(prepared, **kwargs)
    return encoded.to(
        device=self.output_device,
        dtype=self.vae_output_dtype(),
    )
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
@ -1089,11 +1116,19 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@ -1112,7 +1147,20 @@ class VAE:
args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None:
if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3):
tiled_args = {}
if tile_x is not None:
tiled_args["tile_x"] = tile_x
if tile_y is not None:
tiled_args["tile_y"] = tile_y
if overlap is not None:
tiled_args["overlap"] = overlap
if tile_t is not None:
tiled_args["tile_t"] = tile_t
if overlap_t is not None:
tiled_args["overlap_t"] = overlap_t
output = self._decode_tiled_owned(samples, **tiled_args)
elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
@ -1154,6 +1202,8 @@ class VAE:
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
if isinstance(out, tuple):
out = out[0]
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
@ -1173,12 +1223,18 @@ class VAE:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
if formatter is not None:
samples = formatter(samples)
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
@ -1186,7 +1242,7 @@ class VAE:
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
if dims == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
@ -1210,21 +1266,39 @@ class VAE:
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
tiled_args = {}
if tile_x is not None:
tiled_args["tile_x"] = tile_x
if tile_y is not None:
tiled_args["tile_y"] = tile_y
if overlap is not None:
tiled_args["overlap"] = overlap
if tile_t is not None:
tiled_args["tile_t"] = tile_t
if overlap_t is not None:
tiled_args["overlap_t"] = overlap_t
samples = self._encode_tiled_owned(pixel_samples, **tiled_args)
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
spatial_overlap = overlap if overlap is not None else 64
if overlap_t is None:
args["overlap"] = (1, spatial_overlap, spatial_overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
if formatter is not None:
samples = formatter(samples)
return samples
def get_sd(self):
@ -1752,6 +1826,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
    """Call ``model_config.set_inference_dtype``, passing ``device`` only if accepted.

    ``device`` is forwarded as a keyword argument when the target callable
    declares a keyword-capable ``device`` parameter or takes ``**kwargs``.
    Otherwise the two-argument form ``set_inference_dtype(dtype,
    manual_cast_dtype)`` is used.

    Args:
        model_config: object exposing a ``set_inference_dtype`` callable.
        dtype: first positional argument for ``set_inference_dtype``.
        manual_cast_dtype: second positional argument for ``set_inference_dtype``.
        device: forwarded as ``device=`` when the callable can accept it.
    """
    set_dtype = model_config.set_inference_dtype
    try:
        parameters = inspect.signature(set_dtype).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected (e.g. some C-implemented callables
        # or a broken __signature__); fall back to the two-argument call.
        supports_device = False
    else:
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        device_param = parameters.get("device")
        # A positional-only or *device parameter cannot take device= by name.
        named_device = device_param is not None and device_param.kind in keyword_kinds
        takes_var_keyword = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
        )
        supports_device = named_device or takes_var_keyword

    if supports_device:
        set_dtype(dtype, manual_cast_dtype, device=device)
    else:
        set_dtype(dtype, manual_cast_dtype)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
@ -1859,7 +1944,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
if model_config.clip_vision_prefix is not None:
if output_clipvision:
@ -2000,7 +2085,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
if custom_operations is not None:
model_config.custom_operations = custom_operations