From e4fb3a3572c94d8f2ef73ddd18d2a6966ed5a1e5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Dec 2025 14:45:33 -0800 Subject: [PATCH] Support loading Wan/Qwen VAEs with different in/out channels. (#11405) --- comfy/ldm/wan/vae.py | 11 +++++++---- comfy/sd.py | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index ccbb25822..08315f1a8 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -227,6 +227,7 @@ class Encoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + input_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -245,7 +246,7 @@ class Encoder3d(nn.Module): scale = 1.0 # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1) # downsample blocks downsamples = [] @@ -331,6 +332,7 @@ class Decoder3d(nn.Module): def __init__(self, dim=128, z_dim=4, + output_channels=3, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], @@ -378,7 +380,7 @@ class Decoder3d(nn.Module): # output blocks self.head = nn.Sequential( RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + CausalConv3d(out_dim, output_channels, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -449,6 +451,7 @@ class WanVAE(nn.Module): num_res_blocks=2, attn_scales=[], temperal_downsample=[True, True, False], + image_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -460,11 +463,11 @@ class WanVAE(nn.Module): self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): diff --git a/comfy/sd.py b/comfy/sd.py index 1cad98aef..f95c78892 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -546,7 +546,8 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + self.output_channels = sd["encoder.conv1.weight"].shape[1] + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)