Fix VAEEncodeForInpaint to support WAN VAE tuple downscale_ratio

Use vae.spacial_compression_encode() instead of directly accessing
downscale_ratio to handle both standard VAEs (int) and WAN VAEs (tuple).

Addresses reviewer feedback on PR #11259.
This commit is contained in:
ChrisFab16 2025-12-19 10:40:50 +01:00 committed by Rattus
parent 6ca3d5c011
commit 38d7145076

View File

@ -374,14 +374,15 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio downscale_ratio = vae.spacial_compression_encode()
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio x = (pixels.shape[1] // downscale_ratio) * downscale_ratio
y = (pixels.shape[2] // downscale_ratio) * downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone() pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 x_offset = (pixels.shape[1] % downscale_ratio) // 2
y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 y_offset = (pixels.shape[2] % downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]