Trim/pad channels in VAE code. (#11406)

comfyanonymous 2025-12-18 15:22:38 -08:00 committed by GitHub
parent e4fb3a3572
commit 6a2678ac65
2 changed files with 26 additions and 11 deletions

View File

@@ -321,6 +321,7 @@ class VAE:
         self.latent_channels = 4
         self.latent_dim = 2
         self.output_channels = 3
+        self.pad_channel_value = None
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
@@ -435,6 +436,7 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
             self.latent_channels = 64
             self.output_channels = 2
+            self.pad_channel_value = "replicate"
             self.upscale_ratio = 2048
             self.downscale_ratio = 2048
             self.latent_dim = 1
@@ -547,6 +549,7 @@ class VAE:
             self.latent_dim = 3
             self.latent_channels = 16
             self.output_channels = sd["encoder.conv1.weight"].shape[1]
+            self.pad_channel_value = 1.0
             ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
             self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@@ -583,6 +586,7 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
             self.latent_channels = 8
             self.output_channels = 2
+            self.pad_channel_value = "replicate"
             self.upscale_ratio = 4096
             self.downscale_ratio = 4096
             self.latent_dim = 2
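
Taken together, the four hunks above give each VAE family its own channel policy: the base class defaults pad_channel_value to None (no padding), the two 2-channel (stereo audio) VAEs use "replicate", and the Wan video VAE uses the constant 1.0. A minimal sketch of how such a value maps onto torch.nn.functional.pad arguments, mirroring the branch added to vae_encode_crop_pixels below (the helper name and shapes here are illustrative, not from this commit):

import torch
import torch.nn.functional as F

def pad_last_channels(pixels, target_channels, pad_channel_value):
    # Channels-last convention, as in the VAE class: pixels.shape[-1] is C.
    missing = target_channels - pixels.shape[-1]
    if missing <= 0 or pad_channel_value is None:
        return pixels  # nothing to add, or padding disabled (base-class default)
    if isinstance(pad_channel_value, str):
        mode, value = pad_channel_value, None        # e.g. "replicate" (audio VAEs)
    else:
        mode, value = "constant", pad_channel_value  # e.g. 1.0 (Wan VAE)
    return F.pad(pixels, (0, missing), mode=mode, value=value)

mono = torch.randn(1, 44100, 1)                   # channels-last audio: [B, T, C]
stereo = pad_last_channels(mono, 2, "replicate")  # [1, 44100, 2], channel duplicated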
@@ -691,17 +695,28 @@ class VAE:
             raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
 
     def vae_encode_crop_pixels(self, pixels):
-        if not self.crop_input:
-            return pixels
-        downscale_ratio = self.spacial_compression_encode()
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            x = (dims[d] // downscale_ratio) * downscale_ratio
-            x_offset = (dims[d] % downscale_ratio) // 2
-            if x != dims[d]:
-                pixels = pixels.narrow(d + 1, x_offset, x)
+        if self.crop_input:
+            downscale_ratio = self.spacial_compression_encode()
+            dims = pixels.shape[1:-1]
+            for d in range(len(dims)):
+                x = (dims[d] // downscale_ratio) * downscale_ratio
+                x_offset = (dims[d] % downscale_ratio) // 2
+                if x != dims[d]:
+                    pixels = pixels.narrow(d + 1, x_offset, x)
+
+        if pixels.shape[-1] > self.output_channels:
+            pixels = pixels[..., :self.output_channels]
+        elif pixels.shape[-1] < self.output_channels:
+            if self.pad_channel_value is not None:
+                if isinstance(self.pad_channel_value, str):
+                    mode = self.pad_channel_value
+                    value = None
+                else:
+                    mode = "constant"
+                    value = self.pad_channel_value
+                pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels
 
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
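
The rewrite above makes vae_encode_crop_pixels normalize the channel count as well as the spatial size: surplus trailing channels are sliced off, and missing ones are filled according to pad_channel_value. A self-contained sketch of the two new paths (the shapes are assumptions, not from the source):

import torch
import torch.nn.functional as F

rgba = torch.rand(1, 64, 64, 4)  # [B, H, W, C] image batch with an alpha channel
rgb = rgba[..., :3]              # trim path: shape[-1] > output_channels

# Pad path with a constant (Wan-style pad_channel_value = 1.0). Constant
# padding works on tensors of any rank; "replicate" with a 2-element pad is
# only accepted by PyTorch on 2D/3D inputs, i.e. the channels-last audio
# layout [B, T, C] used by the VAEs configured with "replicate" above.
padded = F.pad(rgb, (0, 1), mode="constant", value=1.0)
assert padded.shape[-1] == 4 and torch.all(padded[..., 3] == 1.0)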

View File

@@ -343,7 +343,7 @@ class VAEEncode:
     CATEGORY = "latent"
 
     def encode(self, vae, pixels):
-        t = vae.encode(pixels[:,:,:,:3])
+        t = vae.encode(pixels)
         return ({"samples":t}, )
 
 class VAEEncodeTiled:
@@ -361,7 +361,7 @@ class VAEEncodeTiled:
     CATEGORY = "_for_testing"
 
     def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+        t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
         return ({"samples": t}, )
 
 class VAEEncodeForInpaint:
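
With trimming handled inside the VAE itself (vae_encode_crop_pixels above), both nodes stop hard-coding pixels[:,:,:,:3], which capped every input at three channels regardless of the VAE's own output_channels. For an RGB(A) image batch fed to a 3-channel VAE the result is unchanged; a quick check under assumed shapes:

import torch

pixels = torch.rand(1, 512, 512, 4)  # e.g. an RGBA image batch
old_input = pixels[:, :, :, :3]      # what VAEEncode used to pass
new_input = pixels[..., :3]          # what a 3-channel VAE now trims to internally
assert torch.equal(old_input, new_input)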