diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 40c592a2b..4a503dde4 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -1189,7 +1189,6 @@ class Decoder3D(nn.Module):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
-        print(f"slicing_up_num: {slicing_up_num}")
         for i, up_block_type in enumerate(up_block_types):
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
@@ -1450,6 +1449,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
     def encode(self, x: torch.FloatTensor):
         if x.ndim == 4:
             x = x.unsqueeze(2)
+        x = x.to(next(self.parameters()).dtype)
         p = super().encode(x).latent_dist
         z = p.sample().squeeze(2)
         return z, p
diff --git a/comfy/sd.py b/comfy/sd.py
index 79b17073f..186a69703 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -268,7 +268,8 @@ class CLIP:
 class VAE:
     def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
-            sd = diffusers_convert.convert_vae_state_dict(sd)
+            if metadata is None or metadata.get("keep_diffusers_format") != "true":
+                sd = diffusers_convert.convert_vae_state_dict(sd)
 
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
@@ -326,6 +329,15 @@ class VAE:
                 self.first_stage_model = StageC_coder()
                 self.downscale_ratio = 32
                 self.latent_channels = 16
+            elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
+                self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
+                self.memory_used_decode = lambda shape, dtype: (2000 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.working_dtypes = [torch.bfloat16, torch.float32]
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
             elif "decoder.conv_in.weight" in sd:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -393,17 +405,6 @@ class VAE:
                 self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
-            elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
-                self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
-                ddconfig["conv3d"] = True
-                ddconfig["time_compress"] = 4
-                self.memory_used_decode = lambda shape, dtype: (2000 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-                self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[2], 5) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-                self.working_dtypes = [torch.bfloat16, torch.float32]
-                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
-                self.downscale_index_formula = (4, 8, 8)
-                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
-                self.upscale_index_formula = (4, 8, 8)
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
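For review context, a standalone sanity check of the temporal compression mapping registered above (not part of the patch). The lambdas mirror the time terms of `downscale_ratio` and `upscale_ratio`; the frame counts are illustrative.

```python
import math

# Pixel frames t -> latent frames z (downscale_ratio time term)
down = lambda a: max(0, math.floor((a + 3) / 4))
# Latent frames z -> pixel frames (upscale_ratio time term)
up = lambda a: max(0, a * 4 - 3)

for t in (1, 5, 9, 16, 21):
    z = down(t)
    print(f"{t} frames -> {z} latents -> {up(z)} frames")
# 1 -> 1 -> 1, 5 -> 2 -> 5, 21 -> 6 -> 21: exact round trip when t = 4k + 1.
# 16 -> 4 -> 13: other frame counts collapse to the 4k + 1 grid, which is
# presumably why the input node below runs cut_videos before encoding.
```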
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index e2fa10427..e83e37c1d 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -105,27 +105,46 @@ class SeedVR2InputProcessing(io.ComfyNode):
             category="image/video",
             inputs = [
                 io.Image.Input("images"),
-                io.Int.Input("resolution_height"),
-                io.Int.Input("resolution_width")
+                io.Vae.Input("vae"),
+                io.Int.Input("resolution_height", default = 1280, min = 120), # min just needs
+                io.Int.Input("resolution_width", default = 720, min = 120) # a non-zero value
             ],
             outputs = [
-                io.Image.Output("processed_images")
+                io.Latent.Output("vae_conditioning")
             ]
         )
 
     @classmethod
-    def execute(cls, images, resolution_height, resolution_width):
-        images = images.permute(0, 3, 1, 2)
+    def execute(cls, images, vae, resolution_height, resolution_width):
+        vae = vae.first_stage_model
+        scale = 0.9152; shift = 0
+
+        if images.dim() != 5: # add the t dim
+            images = images.unsqueeze(0)
+        images = images.permute(0, 1, 4, 2, 3)
+
+        b, t, c, h, w = images.shape
+        images = images.reshape(b * t, c, h, w)
+
         max_area = ((resolution_height * resolution_width) ** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
         images = area_resize(images, max_area)
+        images = clip(images)
         images = crop(images, (16, 16))
         images = normalize(images)
-        images = rearrange(images, "t c h w -> c t h w")
+        _, _, new_h, new_w = images.shape
+
+        images = images.reshape(b, t, c, new_h, new_w)
         images = cut_videos(images)
-        return io.NodeOutput(images)
+
+        images = rearrange(images, "b t c h w -> b c t h w")
+        latent = vae.encode(images)[0]
+
+        latent = (latent - shift) * scale
+
+        return io.NodeOutput({"samples": latent})
 
 class SeedVR2Conditioning(io.ComfyNode):
     @classmethod
@@ -150,8 +169,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
 
-        noises = [torch.randn_like(latent) for latent in vae_conditioning]
-        aug_noises = [torch.randn_like(latent) for latent in vae_conditioning]
+        noises = torch.randn_like(vae_conditioning)
+        aug_noises = torch.randn_like(vae_conditioning)
         cond_noise_scale = 0.0
 
         t = (
@@ -165,8 +184,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         pos_shape = pos_cond.shape[1]
         neg_shape = neg_cond.shape[1]
 
-        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
-        if pos_shape.shape[1] > neg_shape.shape[1]:
+        diff = abs(pos_shape - neg_shape)
+        if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
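Also for review context, a minimal standalone sketch of the padding scheme the conditioning hunk above converges on: the shorter of the two text embeddings is zero-padded along the token dimension so positive and negative conditioning match. The shapes are illustrative, not taken from the model.

```python
import torch
import torch.nn.functional as F

# Positive/negative text embeddings with unequal token counts [b, tokens, dim].
pos_cond = torch.randn(1, 77, 1024)
neg_cond = torch.randn(1, 64, 1024)

pos_shape, neg_shape = pos_cond.shape[1], neg_cond.shape[1]
diff = abs(pos_shape - neg_shape)
if pos_shape > neg_shape:
    # F.pad pads trailing dims first: (0, 0) leaves the feature dim alone,
    # (0, diff) appends `diff` zero tokens to the shorter sequence.
    neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
else:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))

assert pos_cond.shape == neg_cond.shape  # torch.Size([1, 77, 1024])
```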