mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
Add SeedVR2 VAE support
This commit is contained in:
parent
cd18c4460a
commit
a7ea0c2773
1807
comfy/ldm/seedvr/vae.py
Normal file
1807
comfy/ldm/seedvr/vae.py
Normal file
File diff suppressed because it is too large
Load Diff
125
comfy/sd.py
125
comfy/sd.py
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -16,6 +17,7 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.seedvr.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
@ -467,8 +469,10 @@ class CLIP:
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
|
||||
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
@ -540,6 +544,20 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 3
|
||||
self.disable_offload = True
|
||||
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
|
||||
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.crop_input = False
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
@ -1006,6 +1024,10 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def _decode_tiled_owned(self, samples, **kwargs):
|
||||
out = self.first_stage_model.decode_tiled(samples.to(self.vae_dtype).to(self.device), **kwargs)
|
||||
return self.process_output(out.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -1042,6 +1064,11 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def _encode_tiled_owned(self, pixel_samples, **kwargs):
|
||||
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode_tiled(x, **kwargs)
|
||||
return out.to(device=self.output_device, dtype=self.vae_output_dtype())
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
@ -1089,11 +1116,19 @@ class VAE:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
@ -1112,7 +1147,20 @@ class VAE:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3):
|
||||
tiled_args = {}
|
||||
if tile_x is not None:
|
||||
tiled_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
tiled_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
tiled_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
tiled_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
tiled_args["overlap_t"] = overlap_t
|
||||
output = self._decode_tiled_owned(samples, **tiled_args)
|
||||
elif dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@ -1154,6 +1202,8 @@ class VAE:
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
@ -1173,12 +1223,18 @@ class VAE:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
|
||||
if formatter is not None:
|
||||
samples = formatter(samples)
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
@ -1186,7 +1242,7 @@ class VAE:
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if dims == 3:
|
||||
if dims == 3 and pixel_samples.ndim < 5:
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
@ -1210,21 +1266,39 @@ class VAE:
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
tiled_args = {}
|
||||
if tile_x is not None:
|
||||
tiled_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
tiled_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
tiled_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
tiled_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
tiled_args["overlap_t"] = overlap_t
|
||||
samples = self._encode_tiled_owned(pixel_samples, **tiled_args)
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
spatial_overlap = overlap if overlap is not None else 64
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, spatial_overlap, spatial_overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
|
||||
if formatter is not None:
|
||||
samples = formatter(samples)
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
@ -1752,6 +1826,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
|
||||
set_dtype = model_config.set_inference_dtype
|
||||
parameters = inspect.signature(set_dtype).parameters
|
||||
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
|
||||
if supports_device:
|
||||
set_dtype(dtype, manual_cast_dtype, device=device)
|
||||
else:
|
||||
set_dtype(dtype, manual_cast_dtype)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
@ -1859,7 +1944,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
@ -2000,7 +2085,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
Loading…
Reference in New Issue
Block a user