From 5039ec8fea26433537d3bb60bd59fc0dc1662518 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:14:43 +0300 Subject: [PATCH] Fix VAE usage --- comfy_extras/nodes_model_patch.py | 42 +++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index f1ba2ed18..be19f3b99 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -271,11 +271,23 @@ class ModelPatchLoader: operations=comfy.ops.manual_cast) elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd: prefix_replace = {} - if 'model.control_model.input_hint_block.0.weight' in sd: + has_full_prefix = 'model.control_model.input_hint_block.0.weight' in sd + if has_full_prefix: prefix_replace["model.control_model."] = "control_model." prefix_replace["model.diffusion_model.project_modules."] = "project_modules." + + # Extract denoise_encoder weights before filter_keys discards them + de_prefix = "first_stage_model.denoise_encoder." + denoise_encoder_sd = {} + for k in list(sd.keys()): + if k.startswith(de_prefix): + denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k) + sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True) + sd.pop("control_model.mask_LQ", None) model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + if denoise_encoder_sd: + model.denoise_encoder_sd = denoise_encoder_sd model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model.load_state_dict(sd, assign=model_patcher.is_dynamic()) @@ -585,7 +597,8 @@ class SUPIRApply(io.ComfyNode): inputs=[ io.Model.Input("model"), io.ModelPatch.Input("model_patch"), - io.Latent.Input("latent"), + io.Vae.Input("vae"), + io.Image.Input("image"), io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Control strength at the start of sampling (high sigma)."), io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01, @@ -599,15 +612,34 @@ class SUPIRApply(io.ComfyNode): ) @classmethod - def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, latent: io.Latent.Type, + def _encode_with_denoise_encoder(cls, vae, model_patch, image): + """Encode using denoise_encoder weights from SUPIR checkpoint if available.""" + denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None) + if not denoise_sd: + return vae.encode(image) + + # Patch VAE encoder with denoise_encoder weights, encode, then unpatch + patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()} + vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0) + try: + return vae.encode(image) + finally: + vae.patcher.unpatch_model() + + @classmethod + def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type, strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput: model_patched = model.clone() - hint_latent = model.get_model_object("latent_format").process_in(latent["samples"]) + hint_latent = model.get_model_object("latent_format").process_in( + cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3])) patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end) patch.register(model_patched) if restore_cfg > 0.0: - x_center = hint_latent.clone() + # Round-trip to match original pipeline: decode hint, re-encode with regular VAE + latent_format = model.get_model_object("latent_format") + decoded = vae.decode(latent_format.process_out(hint_latent)) + x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3])) sigma_max = 14.6146 def restore_cfg_function(args):