This commit is contained in:
kijai 2026-04-01 18:57:33 +03:00
parent a0b496c438
commit 5b7b80fd08
2 changed files with 14 additions and 9 deletions

View File

@ -27,10 +27,10 @@ class SUPIRPatch:
if self.cached_features is not None:
return
x = kwargs["x"]
batch_size = x.shape[0]
b = x.shape[0]
hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
if hint.shape[0] < batch_size:
hint = hint.repeat(batch_size // hint.shape[0], 1, 1, 1)[:batch_size]
if hint.shape[0] != b:
hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
self.cached_features = self.model_patch.model.control_model(
hint, kwargs["timesteps"], x,
kwargs["context"], kwargs["y"]

View File

@ -271,10 +271,12 @@ class ModelPatchLoader:
operations=comfy.ops.manual_cast)
elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd:
prefix_replace = {}
has_full_prefix = 'model.control_model.input_hint_block.0.weight' in sd
if has_full_prefix:
if 'model.control_model.input_hint_block.0.weight' in sd:
prefix_replace["model.control_model."] = "control_model."
prefix_replace["model.diffusion_model.project_modules."] = "project_modules."
else:
prefix_replace["control_model."] = "control_model."
prefix_replace["project_modules."] = "project_modules."
# Extract denoise_encoder weights before filter_keys discards them
de_prefix = "first_stage_model.denoise_encoder."
@ -618,13 +620,15 @@ class SUPIRApply(io.ComfyNode):
if not denoise_sd:
return vae.encode(image)
# Patch VAE encoder with denoise_encoder weights, encode, then unpatch
# Clone VAE patcher, apply denoise_encoder weights to clone, encode
orig_patcher = vae.patcher
vae.patcher = orig_patcher.clone()
patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()}
vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0)
try:
return vae.encode(image)
finally:
vae.patcher.unpatch_model()
vae.patcher = orig_patcher
@classmethod
def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type,
@ -651,8 +655,9 @@ class SUPIRApply(io.ComfyNode):
s = sigma.item()
if s > restore_cfg_s_tmin:
ref = x_center.to(device=denoised.device, dtype=denoised.dtype)
if ref.shape[0] < denoised.shape[0]:
ref = ref.repeat(denoised.shape[0] // ref.shape[0], 1, 1, 1)[:denoised.shape[0]]
b = denoised.shape[0]
if ref.shape[0] != b:
ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b]
sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma
d_center = denoised - ref
denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)