fixed manual VAE loading

Yousef Rafat 2025-12-30 18:44:57 +02:00
parent fadc7839cc
commit 84fa155071
3 changed files with 3 additions and 11 deletions

@@ -1878,8 +1878,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         if latent.ndim == 4:
             latent = latent.unsqueeze(2)
-        target_device = comfy.model_management.get_torch_device()
-        self.decoder.to(target_device)
         if self.tiled_args.get("enable_tiling", None) is not None:
             self.enable_tiling = self.tiled_args.pop("enable_tiling", False)
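The two deleted lines moved only the decoder onto the compute device from inside the wrapper itself. After this commit, device placement is left to ComfyUI's model management: the caller loads the whole patcher up front (see the node change below), so the wrapper can assume its weights are already where they need to be. A minimal sketch of that calling pattern, using only names that appear in this diff (load_models_gpu, vae.patcher, vae.first_stage_model); treat it as illustrative, not the node's exact code:

    import comfy.model_management

    def decode_managed(vae, latent):
        # Let model management place (and later offload) the whole VAE,
        # instead of calling .to(device) on individual sub-modules.
        comfy.model_management.load_models_gpu([vae.patcher])
        latent = latent.to(vae.patcher.load_device)
        return vae.first_stage_model.decode(latent)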

@@ -379,8 +379,8 @@ class VAE:
             self.latent_channels = 16
         elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd:  # seedvr2
             self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
-            self.memory_used_decode = lambda shape, dtype: (2000 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
             self.downscale_index_formula = (4, 8, 8)
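These lambdas are the VAE's estimate, in bytes, of how much memory a decode or encode will need; model management uses them to decide how much to free before running. Dropping the constants from 2000 and 1000 to 10 shrinks the estimates by over two orders of magnitude, so other models are no longer evicted for a decode that doesn't actually need that much. A rough worked example, assuming a channels-last latent of shape (b, t, h, w, c) = (1, 5, 64, 64, 16) in bfloat16; the layout is inferred from the rearrange(latent, "b c ... -> b ... c") in the node below, so treat the exact numbers as illustrative:

    shape = (1, 5, 64, 64, 16)  # hypothetical (b, t, h, w, c) latent
    dtype_size = 2              # bytes per bfloat16 element

    old = 2000 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8) * dtype_size
    new = 10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8) * dtype_size
    print(old, new)  # ~21 GB vs ~105 MB estimated for the same decode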

@@ -332,11 +332,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
-        device = vae.patcher.load_device
-        offload_device = comfy.model_management.intermediate_device()
-        main_device = comfy.model_management.get_torch_device()
-        images = images.to(main_device)
+        comfy.model_management.load_models_gpu([vae.patcher])
         vae_model = vae.first_stage_model
         scale = 0.9152; shift = 0
         if images.dim() != 5:  # add the t dim
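For reference, the images.dim() != 5 check covers single-image batches: the node works on (b, t, c, h, w) video tensors (the rearrange in the next hunk confirms that layout), so a 4-D image batch gets a time axis inserted. A minimal sketch, assuming the time axis is added at dim 1 (the hunk only shows the check, not the unsqueeze):

    import torch

    images = torch.rand(2, 3, 256, 256)  # (b, c, h, w) image batch
    if images.dim() != 5:                # add the t dim, as in the node
        images = images.unsqueeze(1)     # assumed position -> (b, t, c, h, w)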
@@ -360,8 +357,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = cut_videos(images)
         images = rearrange(images, "b t c h w -> b c t h w")
-        images = images.to(device)
-        vae_model = vae_model.to(device)
         # in case the user passes an incompatible number for tiling
         def make_divisible(val, divisor):
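The make_divisible helper (body not shown in this hunk) guards against tile sizes that don't line up with the VAE's spatial and temporal strides. A sketch of the usual rounding idiom such a helper uses; the actual implementation may differ:

    def make_divisible(val, divisor):
        # Hypothetical body: round down to the nearest multiple of divisor,
        # never going below divisor itself.
        return max(divisor, (val // divisor) * divisor)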
@@ -393,7 +388,6 @@ class SeedVR2InputProcessing(io.ComfyNode):
         latent = rearrange(latent, "b c ... -> b ... c")
         latent = (latent - shift) * scale
-        latent = latent.to(offload_device)
         return io.NodeOutput({"samples": latent})
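With the final .to(offload_device) gone, the normalized latent is returned on whatever device model management chose, and any later move is left to the consumer. The normalization itself is unchanged: the latent is scaled into the sampler's expected range using the constants from the first hunk of this file. A quick round-trip check, where the inverse mapping is an assumption (presumably what a decode-side node applies):

    import torch

    scale, shift = 0.9152, 0.0  # constants from SeedVR2InputProcessing

    latent = torch.randn(1, 5, 64, 64, 16)
    normalized = (latent - shift) * scale   # as in the node
    restored = normalized / scale + shift   # assumed inverse for decoding
    assert torch.allclose(restored, latent, atol=1e-5)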