diff --git a/comfy/ldm/supir/supir_patch.py b/comfy/ldm/supir/supir_patch.py index d1f17d982..b67ab4cd8 100644 --- a/comfy/ldm/supir/supir_patch.py +++ b/comfy/ldm/supir/supir_patch.py @@ -27,10 +27,10 @@ class SUPIRPatch: if self.cached_features is not None: return x = kwargs["x"] - batch_size = x.shape[0] + b = x.shape[0] hint = self.hint_latent.to(device=x.device, dtype=x.dtype) - if hint.shape[0] < batch_size: - hint = hint.repeat(batch_size // hint.shape[0], 1, 1, 1)[:batch_size] + if hint.shape[0] != b: + hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b] self.cached_features = self.model_patch.model.control_model( hint, kwargs["timesteps"], x, kwargs["context"], kwargs["y"] diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 6d77a0788..748559a6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -271,10 +271,12 @@ class ModelPatchLoader: operations=comfy.ops.manual_cast) elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd: prefix_replace = {} - has_full_prefix = 'model.control_model.input_hint_block.0.weight' in sd - if has_full_prefix: + if 'model.control_model.input_hint_block.0.weight' in sd: prefix_replace["model.control_model."] = "control_model." prefix_replace["model.diffusion_model.project_modules."] = "project_modules." + else: + prefix_replace["control_model."] = "control_model." + prefix_replace["project_modules."] = "project_modules." # Extract denoise_encoder weights before filter_keys discards them de_prefix = "first_stage_model.denoise_encoder." @@ -618,13 +620,15 @@ class SUPIRApply(io.ComfyNode): if not denoise_sd: return vae.encode(image) - # Patch VAE encoder with denoise_encoder weights, encode, then unpatch + # Clone VAE patcher, apply denoise_encoder weights to clone, encode + orig_patcher = vae.patcher + vae.patcher = orig_patcher.clone() patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()} vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0) try: return vae.encode(image) finally: - vae.patcher.unpatch_model() + vae.patcher = orig_patcher @classmethod def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type, @@ -651,8 +655,9 @@ class SUPIRApply(io.ComfyNode): s = sigma.item() if s > restore_cfg_s_tmin: ref = x_center.to(device=denoised.device, dtype=denoised.dtype) - if ref.shape[0] < denoised.shape[0]: - ref = ref.repeat(denoised.shape[0] // ref.shape[0], 1, 1, 1)[:denoised.shape[0]] + b = denoised.shape[0] + if ref.shape[0] != b: + ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b] sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma d_center = denoised - ref denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg)