diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 4f50810dc..582b44e69 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -155,6 +155,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine): def __init__(self, embed_dim: int, **kwargs): self.max_batch_size = kwargs.pop("max_batch_size", None) ddconfig = kwargs.pop("ddconfig") + decoder_ddconfig = kwargs.pop("decoder_ddconfig", ddconfig) super().__init__( encoder_config={ "target": "comfy.ldm.modules.diffusionmodules.model.Encoder", @@ -162,7 +163,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine): }, decoder_config={ "target": "comfy.ldm.modules.diffusionmodules.model.Decoder", - "params": ddconfig, + "params": decoder_ddconfig, }, **kwargs, ) diff --git a/comfy/sd.py b/comfy/sd.py index 5b6b59ea4..f331feefb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -556,12 +556,19 @@ class VAE: old_memory_used_decode = self.memory_used_decode self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + decoder_ch = sd['decoder.conv_in.weight'].shape[0] // ddconfig['ch_mult'][-1] + if decoder_ch != ddconfig['ch']: + decoder_ddconfig = ddconfig.copy() + decoder_ddconfig['ch'] = decoder_ch + else: + decoder_ddconfig = None + if 'post_quant_conv.weight' in sd: - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1], **({"decoder_ddconfig": decoder_ddconfig} if decoder_ddconfig is not None else {})) else: self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, - decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) + decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': decoder_ddconfig if decoder_ddconfig is not None else ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: config = {} param_key = None