Fix VAE usage

This commit is contained in:
kijai 2026-04-01 18:14:43 +03:00
parent c35def1107
commit 5039ec8fea

View File

@ -271,11 +271,23 @@ class ModelPatchLoader:
operations=comfy.ops.manual_cast)
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
prefix_replace = {}
if 'model.control_model.input_hint_block.0.weight' in sd:
has_full_prefix = 'model.control_model.input_hint_block.0.weight' in sd
if has_full_prefix:
prefix_replace["model.control_model."] = "control_model."
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
# Extract denoise_encoder weights before filter_keys discards them
de_prefix = "first_stage_model.denoise_encoder."
denoise_encoder_sd = {}
for k in list(sd.keys()):
if k.startswith(de_prefix):
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
sd.pop("control_model.mask_LQ", None)
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
if denoise_encoder_sd:
model.denoise_encoder_sd = denoise_encoder_sd
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
@ -585,7 +597,8 @@ class SUPIRApply(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.ModelPatch.Input("model_patch"),
io.Latent.Input("latent"),
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
tooltip="Control strength at the start of sampling (high sigma)."),
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
@ -599,15 +612,34 @@ class SUPIRApply(io.ComfyNode):
)
@classmethod
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, latent: io.Latent.Type,
def _encode_with_denoise_encoder(cls, vae, model_patch, image):
"""Encode using denoise_encoder weights from SUPIR checkpoint if available."""
denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
if not denoise_sd:
return vae.encode(image)
# Patch VAE encoder with denoise_encoder weights, encode, then unpatch
patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
try:
return vae.encode(image)
finally:
vae.patcher.unpatch_model()
@classmethod
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
model_patched = model.clone()
hint_latent = model.get_model_object("latent_format").process_in(latent["samples"])
hint_latent = model.get_model_object("latent_format").process_in(
cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
patch.register(model_patched)
if restore_cfg > 0.0:
x_center = hint_latent.clone()
# Round-trip to match original pipeline: decode hint, re-encode with regular VAE
latent_format = model.get_model_object("latent_format")
decoded = vae.decode(latent_format.process_out(hint_latent))
x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
sigma_max = 14.6146
def restore_cfg_function(args):