mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Fix VAE usage
This commit is contained in:
parent
c35def1107
commit
5039ec8fea
@ -271,11 +271,23 @@ class ModelPatchLoader:
|
|||||||
operations=comfy.ops.manual_cast)
|
operations=comfy.ops.manual_cast)
|
||||||
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
|
||||||
prefix_replace = {}
|
prefix_replace = {}
|
||||||
if 'model.control_model.input_hint_block.0.weight' in sd:
|
has_full_prefix = 'model.control_model.input_hint_block.0.weight' in sd
|
||||||
|
if has_full_prefix:
|
||||||
prefix_replace["model.control_model."] = "control_model."
|
prefix_replace["model.control_model."] = "control_model."
|
||||||
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
|
||||||
|
|
||||||
|
# Extract denoise_encoder weights before filter_keys discards them
|
||||||
|
de_prefix = "first_stage_model.denoise_encoder."
|
||||||
|
denoise_encoder_sd = {}
|
||||||
|
for k in list(sd.keys()):
|
||||||
|
if k.startswith(de_prefix):
|
||||||
|
denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k)
|
||||||
|
|
||||||
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True)
|
||||||
|
sd.pop("control_model.mask_LQ", None)
|
||||||
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||||
|
if denoise_encoder_sd:
|
||||||
|
model.denoise_encoder_sd = denoise_encoder_sd
|
||||||
|
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
model.load_state_dict(sd, assign=model_patcher.is_dynamic())
|
||||||
@ -585,7 +597,8 @@ class SUPIRApply(io.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.ModelPatch.Input("model_patch"),
|
io.ModelPatch.Input("model_patch"),
|
||||||
io.Latent.Input("latent"),
|
io.Vae.Input("vae"),
|
||||||
|
io.Image.Input("image"),
|
||||||
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
|
io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||||
tooltip="Control strength at the start of sampling (high sigma)."),
|
tooltip="Control strength at the start of sampling (high sigma)."),
|
||||||
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
|
io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01,
|
||||||
@ -599,15 +612,34 @@ class SUPIRApply(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, latent: io.Latent.Type,
|
def _encode_with_denoise_encoder(cls, vae, model_patch, image):
|
||||||
|
"""Encode using denoise_encoder weights from SUPIR checkpoint if available."""
|
||||||
|
denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None)
|
||||||
|
if not denoise_sd:
|
||||||
|
return vae.encode(image)
|
||||||
|
|
||||||
|
# Patch VAE encoder with denoise_encoder weights, encode, then unpatch
|
||||||
|
patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
|
||||||
|
vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
|
||||||
|
try:
|
||||||
|
return vae.encode(image)
|
||||||
|
finally:
|
||||||
|
vae.patcher.unpatch_model()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
|
||||||
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
|
strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput:
|
||||||
model_patched = model.clone()
|
model_patched = model.clone()
|
||||||
hint_latent = model.get_model_object("latent_format").process_in(latent["samples"])
|
hint_latent = model.get_model_object("latent_format").process_in(
|
||||||
|
cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3]))
|
||||||
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
|
patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end)
|
||||||
patch.register(model_patched)
|
patch.register(model_patched)
|
||||||
|
|
||||||
if restore_cfg > 0.0:
|
if restore_cfg > 0.0:
|
||||||
x_center = hint_latent.clone()
|
# Round-trip to match original pipeline: decode hint, re-encode with regular VAE
|
||||||
|
latent_format = model.get_model_object("latent_format")
|
||||||
|
decoded = vae.decode(latent_format.process_out(hint_latent))
|
||||||
|
x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3]))
|
||||||
sigma_max = 14.6146
|
sigma_max = 14.6146
|
||||||
|
|
||||||
def restore_cfg_function(args):
|
def restore_cfg_function(args):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user