Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-03-06 01:37:45 +08:00)
Disable dynamic_vram when weight hooks applied (#12653)
* sd: add support for CLIP model reconstruction
* nodes: SetClipHooks: demote the dynamic model patcher
* mp: make dynamic_disable more robust. The backups must not be cloned; in addition, add a delegate object to ModelPatcherDynamic so that non-cloning code can perform ModelPatcherDynamic demotion.
* sampler_helpers: demote to a non-dynamic model patcher when hooking
* address CodeRabbit review comments
This commit is contained in: parent 1f6744162f, commit 5f41584e96
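Background for the change: a weight hook patches model weights in place for the duration of a sampling call and later restores them from a backup, which is exactly the state the plain ModelPatcher keeps and the dynamic_vram streaming path does not. The snippet below is only a minimal, self-contained illustration of that backup/restore idea; TinyPatcher and its methods are invented for this note and are not ComfyUI APIs.

# Minimal, hypothetical illustration of the backup/restore pattern that weight
# hooks rely on. None of these names exist in ComfyUI; the real ModelPatcher is
# far more involved.
import torch

class TinyPatcher:
    def __init__(self, model):
        self.model = model
        self.backup = {}          # original weights, kept so hooks can be undone

    def apply_hook(self, key, delta):
        w = dict(self.model.named_parameters())[key]
        if key not in self.backup:
            self.backup[key] = w.detach().clone()
        with torch.no_grad():
            w += delta            # patch the weight in place

    def restore(self):
        params = dict(self.model.named_parameters())
        with torch.no_grad():
            for key, original in self.backup.items():
                params[key].copy_(original)
        self.backup.clear()

model = torch.nn.Linear(4, 4)
patcher = TinyPatcher(model)
before = model.weight.detach().clone()
patcher.apply_hook("weight", torch.full_like(model.weight, 0.1))
patcher.restore()
assert torch.equal(model.weight, before)   # hooks undone from the backup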
comfy/model_patcher.py

@@ -308,15 +308,22 @@ class ModelPatcher:
     def get_free_memory(self, device):
         return comfy.model_management.get_free_memory(device)

-    def clone(self, disable_dynamic=False):
+    def get_clone_model_override(self):
+        return self.model, (self.backup, self.object_patches_backup, self.pinned)
+
+    def clone(self, disable_dynamic=False, model_override=None):
         class_ = self.__class__
-        model = self.model
         if self.is_dynamic() and disable_dynamic:
             class_ = ModelPatcher
-            temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
-            model = temp_model_patcher.model
+            if model_override is None:
+                if self.cached_patcher_init is None:
+                    raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
+                temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
+                model_override = temp_model_patcher.get_clone_model_override()
+        if model_override is None:
+            model_override = self.get_clone_model_override()

-        n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
+        n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -325,13 +332,12 @@ class ModelPatcher:
         n.object_patches = self.object_patches.copy()
         n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
         n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
-        n.backup = self.backup
-        n.object_patches_backup = self.object_patches_backup
         n.parent = self
-        n.pinned = self.pinned

         n.force_cast_weights = self.force_cast_weights

+        n.backup, n.object_patches_backup, n.pinned = model_override[1]
+
         # attachments
         n.attachments = {}
         for k in self.attachments:
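The point of get_clone_model_override() and the new model_override parameter is that the backup, object_patches_backup and pinned containers are handed to the clone by reference instead of being copied along with the patcher, so every clone of the same model restores from (and pins into) the same place. A rough, self-contained sketch of that sharing, using an invented stand-in class rather than the real ModelPatcher:

# Invented stand-in: clones receive the same backup containers by reference,
# mirroring the (model, (backup, object_patches_backup, pinned)) tuple returned
# by get_clone_model_override() in the diff above.
class SharedStatePatcher:
    def __init__(self, model, backup=None, pinned=None):
        self.model = model
        self.backup = {} if backup is None else backup
        self.pinned = set() if pinned is None else pinned

    def get_clone_model_override(self):
        return self.model, (self.backup, self.pinned)

    def clone(self, model_override=None):
        if model_override is None:
            model_override = self.get_clone_model_override()
        model, (backup, pinned) = model_override
        return SharedStatePatcher(model, backup=backup, pinned=pinned)

base = SharedStatePatcher(model="unet")
a = base.clone()
b = a.clone()
a.backup["blocks.0.weight"] = "original weights"
assert base.backup is b.backup              # one shared backup dict across clones
assert "blocks.0.weight" in b.backup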
@@ -1435,6 +1441,7 @@ class ModelPatcherDynamic(ModelPatcher):
         del self.model.model_loaded_weight_memory
         if not hasattr(self.model, "dynamic_vbars"):
             self.model.dynamic_vbars = {}
+        self.non_dynamic_delegate_model = None
         assert load_device is not None

     def is_dynamic(self):

@@ -1669,4 +1676,10 @@ class ModelPatcherDynamic(ModelPatcher):
     def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
         pass

+    def get_non_dynamic_delegate(self):
+        model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
+        self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
+        return model_patcher
+
+
 CoreModelPatcher = ModelPatcher
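get_non_dynamic_delegate() memoizes the delegate's model and backup state in non_dynamic_delegate_model, so only the first demotion pays for running cached_patcher_init (which reloads the model through its original loader); later demotions reuse the same reconstructed model. A hypothetical, much-simplified sketch of that memoization, with names invented for this note:

# Hypothetical sketch of the delegate caching: the first request rebuilds a
# plain patcher through a loader callback, later requests reuse the cached result.
class DelegatingPatcher:
    def __init__(self, loader):
        self.cached_patcher_init = (loader, ())     # (callable, args), as in the diff
        self.non_dynamic_delegate_model = None

    def clone(self, disable_dynamic=False, model_override=None):
        if disable_dynamic and model_override is None:
            fn, args = self.cached_patcher_init
            model_override = fn(*args)               # expensive rebuild happens here
        return {"model": model_override, "patches": {}}   # stand-in for a fresh patcher

    def get_non_dynamic_delegate(self):
        delegate = self.clone(disable_dynamic=True,
                              model_override=self.non_dynamic_delegate_model)
        self.non_dynamic_delegate_model = delegate["model"]   # remember for next time
        return delegate

calls = []
def fake_loader():
    calls.append(1)
    return "reconstructed model weights"

p = DelegatingPatcher(fake_loader)
first = p.get_non_dynamic_delegate()
second = p.get_non_dynamic_delegate()
assert len(calls) == 1                      # the loader only ran once
assert first["model"] is second["model"]    # both delegates wrap the same model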
comfy/sampler_helpers.py

@@ -66,6 +66,18 @@ def convert_cond(cond):
         out.append(temp)
     return out

+def cond_has_hooks(cond):
+    for c in cond:
+        temp = c[1]
+        if "hooks" in temp:
+            return True
+        if "control" in temp:
+            control = temp["control"]
+            extra_hooks = control.get_extra_hooks()
+            if len(extra_hooks) > 0:
+                return True
+    return False
+
 def get_additional_models(conds, dtype):
     """loads additional models in conditioning"""
     cnets: list[ControlBase] = []
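cond_has_hooks walks each conditioning entry, a (cond, options_dict) pair, and reports hooks attached either directly under the "hooks" key or indirectly through a ControlNet's extra hooks. The example below exercises the helper's exact logic on hand-built stand-ins; FakeControl and the literal entries are invented for illustration and are far simpler than real ComfyUI conditioning.

# Stand-in conditioning entries: each entry is (cond, options_dict),
# mirroring the c[1] lookup in the new cond_has_hooks helper.
def cond_has_hooks(cond):
    for c in cond:
        temp = c[1]
        if "hooks" in temp:
            return True
        if "control" in temp:
            control = temp["control"]
            extra_hooks = control.get_extra_hooks()
            if len(extra_hooks) > 0:
                return True
    return False

class FakeControl:
    def __init__(self, extra_hooks):
        self._extra_hooks = extra_hooks
    def get_extra_hooks(self):
        return self._extra_hooks

plain      = [("embedding", {})]
hooked     = [("embedding", {"hooks": object()})]
controlled = [("embedding", {"control": FakeControl(["lora_hook"])})]

assert cond_has_hooks(plain) is False
assert cond_has_hooks(hooked) is True
assert cond_has_hooks(controlled) is True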
comfy/samplers.py

@@ -946,6 +946,8 @@ class CFGGuider:

     def inner_set_conds(self, conds):
         for k in conds:
+            if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]):
+                self.model_patcher = self.model_patcher.get_non_dynamic_delegate()
             self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])

     def __call__(self, *args, **kwargs):
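The check added to inner_set_conds is where the pieces meet: if the active patcher is dynamic and the conditioning for any cond type carries hooks, the guider replaces its patcher with the non-dynamic delegate before sampling starts. A self-contained toy of that decision; DemotingGuider and StubPatcher are invented here and only mimic the shape of the real CFGGuider logic:

# Invented stand-ins that only mimic the demotion decision in inner_set_conds.
class StubPatcher:
    def __init__(self, dynamic):
        self.dynamic = dynamic
    def is_dynamic(self):
        return self.dynamic
    def get_non_dynamic_delegate(self):
        return StubPatcher(dynamic=False)

def cond_has_hooks(cond):
    # simplified: only the direct "hooks" key, see the full helper above
    return any("hooks" in c[1] for c in cond)

class DemotingGuider:
    def __init__(self, patcher):
        self.model_patcher = patcher
    def inner_set_conds(self, conds):
        for k in conds:
            if self.model_patcher.is_dynamic() and cond_has_hooks(conds[k]):
                self.model_patcher = self.model_patcher.get_non_dynamic_delegate()

guider = DemotingGuider(StubPatcher(dynamic=True))
guider.inner_set_conds({"positive": [("emb", {"hooks": object()})],
                        "negative": [("emb", {})]})
assert not guider.model_patcher.is_dynamic()   # demoted because hooks were present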
comfy/sd.py

@@ -204,7 +204,7 @@ def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip


 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False):
         if no_init:
             return
         params = target.params.copy()
@@ -233,7 +233,8 @@ class CLIP:
         model_management.archive_model_dtypes(self.cond_stage_model)

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
+        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         #Match torch.float32 hardcode upcast in TE implemention
         self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -267,9 +268,9 @@ class CLIP:
         logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
         self.tokenizer_options = {}

-    def clone(self):
+    def clone(self, disable_dynamic=False):
         n = CLIP(no_init=True)
-        n.patcher = self.patcher.clone()
+        n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic)
         n.cond_stage_model = self.cond_stage_model
         n.tokenizer = self.tokenizer
         n.layer_idx = self.layer_idx
@@ -1164,14 +1165,21 @@ class CLIPType(Enum):
     LONGCAT_IMAGE = 26


-def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
+def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
+    clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic)
+    return clip.patcher
+
+
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
         if model_options.get("custom_operations", None) is None:
             sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
         clip_data.append(sd)
-    return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
+    clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic)
+    clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options))
+    return clip


 class TEModel(Enum):
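Threaded through comfy/sd.py, disable_dynamic selects the plain ModelPatcher for the text encoder, and load_clip now records cached_patcher_init so the text-encoder patcher can later be rebuilt in non-dynamic form. The usage sketch below assumes a working ComfyUI environment and uses a placeholder file path; it is untested and only shows how the new entry points are meant to compose.

# Hypothetical usage; requires a ComfyUI install and a real text-encoder file
# (the path below is a placeholder).
import comfy.sd

paths = ["models/text_encoders/clip_l.safetensors"]   # placeholder path
clip = comfy.sd.load_clip(paths, clip_type=comfy.sd.CLIPType.STABLE_DIFFUSION)

# load_clip now records how to rebuild this patcher in non-dynamic form:
loader_fn, loader_args = clip.patcher.cached_patcher_init
assert loader_fn is comfy.sd.load_clip_model_patcher

# Demoting a CLIP clone routes disable_dynamic down to the patcher clone:
static_clip = clip.clone(disable_dynamic=True)
assert not static_clip.patcher.is_dynamic()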
@@ -1276,7 +1284,7 @@ def llama_detect(clip_data):

     return {}

-def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
+def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False):
     clip_data = state_dicts

     class EmptyClass:
@@ -1496,7 +1504,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic)
     return clip

 def load_gligen(ckpt_path):
@@ -1541,8 +1549,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
     if out is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
-    if output_model:
+    if output_model and out[0] is not None:
         out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
+    if output_clip and out[1] is not None:
+        out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
     return out

 def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
@@ -1553,6 +1563,14 @@ def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None,
                                        disable_dynamic=disable_dynamic)
     return model

+def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
+    _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False,
+                                               embedding_directory=embedding_directory, output_model=False,
+                                               model_options=model_options,
+                                               te_model_options=te_model_options,
+                                               disable_dynamic=disable_dynamic)
+    return clip.patcher
+
 def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
     clip = None
     clipvision = None
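For full checkpoints, both outputs now remember their loader: out[0] keeps the existing model-only reconstruction, and out[1].patcher gets the new load_checkpoint_guess_config_clip_only, which returns clip.patcher so it satisfies the same cached_patcher_init contract clone() relies on. Another untested, hypothetical usage sketch (placeholder checkpoint path, ComfyUI environment assumed):

# Hypothetical usage; the checkpoint path is a placeholder.
import comfy.sd

ckpt = "models/checkpoints/sd_xl_base_1.0.safetensors"   # placeholder path
out = comfy.sd.load_checkpoint_guess_config(ckpt)
clip = out[1]

loader_fn, loader_args = clip.patcher.cached_patcher_init
assert loader_fn is comfy.sd.load_checkpoint_guess_config_clip_only

# Rebuild the text-encoder patcher without dynamic VRAM streaming:
plain_te_patcher = loader_fn(*loader_args, disable_dynamic=True)
assert not plain_te_patcher.is_dynamic()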
@@ -1638,7 +1656,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         clip_sd = model_config.process_clip_state_dict(sd)
         if len(clip_sd) > 0:
             parameters = comfy.utils.calculate_parameters(clip_sd)
-            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
+            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic)
         else:
             logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

comfy_extras/nodes_hooks.py

@@ -248,7 +248,7 @@ class SetClipHooks:

     def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None):
         if hooks is not None:
-            clip = clip.clone()
+            clip = clip.clone(disable_dynamic=True)
             if apply_to_conds:
                 clip.apply_hooks_to_conds = hooks
             clip.patcher.forced_hooks = hooks.clone()
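In the SetClipHooks node the demotion happens at the clone the node already makes: the incoming clip is left untouched, and only the hooked copy drops dynamic VRAM. The toy below is not the node's code; it just illustrates that clone-before-mutate behaviour with an invented FakeClip class.

# Invented FakeClip; only illustrates why the node clones before attaching hooks.
class FakeClip:
    def __init__(self, dynamic=True, forced_hooks=None):
        self.dynamic = dynamic
        self.forced_hooks = forced_hooks

    def clone(self, disable_dynamic=False):
        return FakeClip(dynamic=self.dynamic and not disable_dynamic,
                        forced_hooks=self.forced_hooks)

upstream = FakeClip()                          # what other graph branches keep using
hooked = upstream.clone(disable_dynamic=True)  # what SetClipHooks hands downstream
hooked.forced_hooks = ["hook group"]

assert upstream.dynamic and upstream.forced_hooks is None   # untouched
assert not hooked.dynamic and hooked.forced_hooks           # demoted and hooked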