Disable dynamic_vram when weight hooks applied (#12653)

* sd: add support for clip model reconstruction

* nodes: SetClipHooks: Demote the dynamic model patcher

* mp: Make dynamic_disable more robust

The backup need to not be cloned. In addition add a delegate object
to ModelPatcherDynamic so that non-cloning code can do
ModelPatcherDynamic demotion

* sampler_helpers: Demote to non-dynamic model patcher when hooking

* code rabbit review comments
This commit is contained in:
rattus 2026-02-28 13:50:18 -08:00 committed by GitHub
parent 1f6744162f
commit 5f41584e96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 64 additions and 19 deletions

View File

@ -308,15 +308,22 @@ class ModelPatcher:
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
def clone(self, disable_dynamic=False):
def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model = temp_model_patcher.model
if model_override is None:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@ -325,13 +332,12 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
n.backup, n.object_patches_backup, n.pinned = model_override[1]
# attachments
n.attachments = {}
for k in self.attachments:
@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
assert load_device is not None
def is_dynamic(self):
@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher):
def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
pass
def get_non_dynamic_delegate(self):
model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
return model_patcher
CoreModelPatcher = ModelPatcher

View File

@ -66,6 +66,18 @@ def convert_cond(cond):
out.append(temp)
return out
def cond_has_hooks(cond):
for c in cond:
temp = c[1]
if "hooks" in temp:
return True
if "control" in temp:
control = temp["control"]
extra_hooks = control.get_extra_hooks()
if len(extra_hooks) > 0:
return True
return False
def get_additional_models(conds, dtype):
"""loads additional models in conditioning"""
cnets: list[ControlBase] = []

View File

@ -946,6 +946,8 @@ class CFGGuider:
def inner_set_conds(self, conds):
for k in conds:
if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])
def __call__(self, *args, **kwargs):

View File

@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
if no_init:
return
params = target.params.copy()
@ -233,7 +233,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@ -267,9 +268,9 @@ class CLIP:
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
self.tokenizer_options = {}
def clone(self):
def clone(self, disable_dynamic=False):
n = CLIP(no_init=True)
n.patcher = self.patcher.clone()
n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
@ -1164,14 +1165,21 @@ class CLIPType(Enum):
LONGCAT_IMAGE = 26
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
return clip.patcher
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if model_options.get("custom_operations", None) is None:
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
return clip
class TEModel(Enum):
@ -1276,7 +1284,7 @@ def llama_detect(clip_data):
return {}
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
clip_data = state_dicts
class EmptyClass:
@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
return clip
def load_gligen(ckpt_path):
@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model:
if output_model and out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
if output_clip and out[1] is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
disable_dynamic=disable_dynamic)
return model
def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
_, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
embedding_directory=embedding_directory, output_model=False,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return clip.patcher
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
clip = None
clipvision = None
@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

View File

@ -248,7 +248,7 @@ class SetClipHooks:
def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
if hooks is not None:
clip = clip.clone()
clip = clip.clone(disable_dynamic=True)
if apply_to_conds:
clip.apply_hooks_to_conds = hooks
clip.patcher.forced_hooks = hooks.clone()