From 3597b27515ee827190765596ab39a78e3f03f1ae Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 13 Jan 2026 15:46:39 +1000 Subject: [PATCH] models: Use CoreModelPatcher Use CoreModelPatcher for all internal ModelPatcher implementations. This drives conditional use of the aimdo feature, while making sure custom node packs get to keep ModelPatcher unchanged for the moment. --- comfy/audio_encoders/audio_encoders.py | 4 +-- comfy/clip_vision.py | 4 +-- comfy/controlnet.py | 2 +- comfy/ldm/hunyuan_video/upsampler.py | 4 +-- comfy/model_base.py | 4 +-- comfy/sd.py | 39 +++++++++++++++----------- comfy_extras/nodes_model_patch.py | 6 ++-- 7 files changed, 34 insertions(+), 29 deletions(-) diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..16998af94 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 66f2a9d9c..e4dd4e678 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -47,10 +47,10 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..9e1e704e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -109,10 +109,10 @@ class HunyuanVideo15SRModel(): self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/model_base.py b/comfy/model_base.py index b9e2abdb9..f4b037bb9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -298,7 +298,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -306,7 +306,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/sd.py b/comfy/sd.py index f83164d23..ca268cbbe 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -128,7 +128,7 @@ class CLIP: logging.warning("Had to shift TE back.") self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -288,7 +288,7 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: return self.cond_stage_model.load_sd(sd) @@ -665,13 +665,6 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) - if device is None: device = model_management.vae_device() self.device = device @@ -682,7 +675,18 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -1315,7 +1319,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1403,7 +1407,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1446,7 +1451,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1538,13 +1542,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 1355b3c93..a3303ea1a 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -254,9 +254,9 @@ class ModelPatchLoader: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=self.model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: