Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 03:23:00 +08:00)
Support loading Wan/Qwen VAEs with different in/out channels. (#11405)
This commit is contained in:
parent e8ebbe668e
commit e4fb3a3572
@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
dim=128,
|
dim=128,
|
||||||
z_dim=4,
|
z_dim=4,
|
||||||
|
input_channels=3,
|
||||||
dim_mult=[1, 2, 4, 4],
|
dim_mult=[1, 2, 4, 4],
|
||||||
num_res_blocks=2,
|
num_res_blocks=2,
|
||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
@ -245,7 +246,7 @@ class Encoder3d(nn.Module):
|
|||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
# init block
|
# init block
|
||||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
self.conv1 = CausalConv3d(input_channels, dims[0], 3, padding=1)
|
||||||
|
|
||||||
# downsample blocks
|
# downsample blocks
|
||||||
downsamples = []
|
downsamples = []
|
||||||
@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
dim=128,
|
dim=128,
|
||||||
z_dim=4,
|
z_dim=4,
|
||||||
|
output_channels=3,
|
||||||
dim_mult=[1, 2, 4, 4],
|
dim_mult=[1, 2, 4, 4],
|
||||||
num_res_blocks=2,
|
num_res_blocks=2,
|
||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
@ -378,7 +380,7 @@ class Decoder3d(nn.Module):
|
|||||||
# output blocks
|
# output blocks
|
||||||
self.head = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
## conv1
|
## conv1
|
||||||
@ -449,6 +451,7 @@ class WanVAE(nn.Module):
|
|||||||
num_res_blocks=2,
|
num_res_blocks=2,
|
||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
|
image_channels=3,
|
||||||
dropout=0.0):
|
dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -460,11 +463,11 @@ class WanVAE(nn.Module):
|
|||||||
self.temperal_upsample = temperal_downsample[::-1]
|
self.temperal_upsample = temperal_downsample[::-1]
|
||||||
|
|
||||||
# modules
|
# modules
|
||||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
self.encoder = Encoder3d(dim, z_dim * 2, image_channels, dim_mult, num_res_blocks,
|
||||||
attn_scales, self.temperal_downsample, dropout)
|
attn_scales, self.temperal_downsample, dropout)
|
||||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
attn_scales, self.temperal_upsample, dropout)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
|
|||||||
@ -546,7 +546,8 @@ class VAE:
|
|||||||
self.downscale_index_formula = (4, 8, 8)
|
self.downscale_index_formula = (4, 8, 8)
|
||||||
self.latent_dim = 3
|
self.latent_dim = 3
|
||||||
self.latent_channels = 16
|
self.latent_channels = 16
|
||||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||||
|
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
||||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user