diff --git a/nodes.py b/nodes.py
index 8d28a725d..cf0325dfe 100644
--- a/nodes.py
+++ b/nodes.py
@@ -374,14 +374,32 @@ class VAEEncodeForInpaint:
     CATEGORY = "latent/inpaint"
 
     def encode(self, vae, pixels, mask, grow_mask_by=6):
-        x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
-        y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
+        # Handle WAN VAE downscale_ratio which can be a tuple (function, min, max)
+        if isinstance(vae.downscale_ratio, tuple):
+            # For WAN VAEs: (lambda function, min_value, max_value)
+            downscale_func = vae.downscale_ratio[0]
+            x_downscale = downscale_func(pixels.shape[1])
+            y_downscale = downscale_func(pixels.shape[2])
+            x = (pixels.shape[1] // x_downscale) * x_downscale
+            y = (pixels.shape[2] // y_downscale) * y_downscale
+        else:
+            # Standard integer downscale_ratio
+            x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio
+            y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio
+
         mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
 
         pixels = pixels.clone()
         if pixels.shape[1] != x or pixels.shape[2] != y:
-            x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
-            y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
+            if isinstance(vae.downscale_ratio, tuple):
+                downscale_func = vae.downscale_ratio[0]
+                x_downscale = downscale_func(pixels.shape[1])
+                y_downscale = downscale_func(pixels.shape[2])
+                x_offset = (pixels.shape[1] % x_downscale) // 2
+                y_offset = (pixels.shape[2] % y_downscale) // 2
+            else:
+                x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2
+                y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2
             pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
             mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]