From f3968d16112d05b4d5b579f3fe9b7938c6f6c20c Mon Sep 17 00:00:00 2001
From: patientx
Date: Tue, 3 Dec 2024 00:10:22 +0300
Subject: [PATCH] Revert "Merge branch 'comfyanonymous:master' into master"

This reverts commit 605425bdd64d67182fb2a16b284b833994952c6a, reversing
changes made to 74e6ad95f71f1d8405f04609ee1ab6a163d70276.
---
 comfy/controlnet.py                          |  22 +-
 comfy/hooks.py                               | 690 --------------
 comfy/k_diffusion/sampling.py                |   1 +
 .../modules/diffusionmodules/openaimodel.py  |  17 -
 comfy/lora.py                                |  16 +-
 comfy/model_base.py                          |  12 -
 comfy/model_management.py                    | 198 ++--
 comfy/model_patcher.py                       | 860 ++++--------
 comfy/patcher_extension.py                   | 156 ----
 comfy/sampler_helpers.py                     |  72 +-
 comfy/samplers.py                            | 365 ++------
 comfy/sd.py                                  |  73 --
 comfy_extras/nodes_clip_sdxl.py              |   6 +-
 comfy_extras/nodes_flux.py                   |   5 +-
 comfy_extras/nodes_hooks.py                  | 697 --------------
 comfy_extras/nodes_hunyuan.py                |   4 +-
 comfy_extras/nodes_sd3.py                    |   3 +-
 execution.py                                 |   2 +-
 main.py                                      |   1 +
 nodes.py                                     |   6 +-
 20 files changed, 410 insertions(+), 2796 deletions(-)
 delete mode 100644 comfy/hooks.py
 delete mode 100644 comfy/patcher_extension.py
 delete mode 100644 comfy_extras/nodes_hooks.py

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index e6a0d1e59..a44f3725e 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -36,10 +36,6 @@ import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
 import comfy.cldm.dit_embedder
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from comfy.hooks import HookGroup
-
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -82,7 +78,6 @@ class ControlBase:
         self.concat_mask = False
         self.extra_concat_orig = []
         self.extra_concat = None
-        self.extra_hooks: HookGroup = None
         self.preprocess_image = lambda a: a
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
@@ -120,14 +115,6 @@ class ControlBase:
         if self.previous_controlnet is not None:
             out += self.previous_controlnet.get_models()
         return out
-
-    def get_extra_hooks(self):
-        out = []
-        if self.extra_hooks is not None:
-            out.append(self.extra_hooks)
-        if self.previous_controlnet is not None:
-            out += self.previous_controlnet.get_extra_hooks()
-        return out
 
     def copy_to(self, c):
         c.cond_hint_original = self.cond_hint_original
@@ -143,7 +130,6 @@ class ControlBase:
         c.strength_type = self.strength_type
         c.concat_mask = self.concat_mask
         c.extra_concat_orig = self.extra_concat_orig.copy()
-        c.extra_hooks = self.extra_hooks.clone() if self.extra_hooks else None
         c.preprocess_image = self.preprocess_image
 
     def inference_memory_requirements(self, dtype):
@@ -214,10 +200,10 @@ class ControlNet(ControlBase):
         self.concat_mask = concat_mask
         self.preprocess_image = preprocess_image
 
-    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+    def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
         if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
 
         if self.timestep_range is not None:
             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -772,10 +758,10 @@ class T2IAdapter(ControlBase):
             height = math.ceil(height / unshuffle_amount) * unshuffle_amount
         return width, height
 
-    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+    def
get_control(self, x_noisy, t, cond, batched_number): control_prev = None if self.previous_controlnet is not None: - control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options) + control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number) if self.timestep_range is not None: if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]: diff --git a/comfy/hooks.py b/comfy/hooks.py deleted file mode 100644 index ccb8183b9..000000000 --- a/comfy/hooks.py +++ /dev/null @@ -1,690 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING, Callable -import enum -import math -import torch -import numpy as np -import itertools - -if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher, PatcherInjection - from comfy.model_base import BaseModel - from comfy.sd import CLIP -import comfy.lora -import comfy.model_management -import comfy.patcher_extension -from node_helpers import conditioning_set_values - -class EnumHookMode(enum.Enum): - MinVram = "minvram" - MaxSpeed = "maxspeed" - -class EnumHookType(enum.Enum): - Weight = "weight" - Patch = "patch" - ObjectPatch = "object_patch" - AddModels = "add_models" - Callbacks = "callbacks" - Wrappers = "wrappers" - SetInjections = "add_injections" - -class EnumWeightTarget(enum.Enum): - Model = "model" - Clip = "clip" - -class _HookRef: - pass - -# NOTE: this is an example of how the should_register function should look -def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return True - - -class Hook: - def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, - hook_keyframe: 'HookKeyframeGroup'=None): - self.hook_type = hook_type - self.hook_ref = hook_ref if hook_ref else _HookRef() - self.hook_id = hook_id - self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() - self.custom_should_register = default_should_register - self.auto_apply_to_nonpositive = False - - @property - def strength(self): - return self.hook_keyframe.strength - - def initialize_timesteps(self, model: 'BaseModel'): - self.reset() - self.hook_keyframe.initialize_timesteps(model) - - def reset(self): - self.hook_keyframe.reset() - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: Hook = subtype() - c.hook_type = self.hook_type - c.hook_ref = self.hook_ref - c.hook_id = self.hook_id - c.hook_keyframe = self.hook_keyframe - c.custom_should_register = self.custom_should_register - # TODO: make this do something - c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive - return c - - def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return self.custom_should_register(self, model, model_options, target, registered) - - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - - def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): - pass - - def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): - pass - - def __eq__(self, other: 'Hook'): - return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref - - def __hash__(self): - return hash(self.hook_ref) - -class WeightHook(Hook): - def __init__(self, strength_model=1.0, 
strength_clip=1.0): - super().__init__(hook_type=EnumHookType.Weight) - self.weights: dict = None - self.weights_clip: dict = None - self.need_weight_init = True - self._strength_model = strength_model - self._strength_clip = strength_clip - - @property - def strength_model(self): - return self._strength_model * self.strength - - @property - def strength_clip(self): - return self._strength_clip * self.strength - - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): - return False - weights = None - if target == EnumWeightTarget.Model: - strength = self._strength_model - else: - strength = self._strength_clip - - if self.need_weight_init: - key_map = {} - if target == EnumWeightTarget.Model: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - else: - key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) - weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) - else: - if target == EnumWeightTarget.Model: - weights = self.weights - else: - weights = self.weights_clip - k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) - registered.append(self) - return True - # TODO: add logs about any keys that were not applied - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WeightHook = super().clone(subtype) - c.weights = self.weights - c.weights_clip = self.weights_clip - c.need_weight_init = self.need_weight_init - c._strength_model = self._strength_model - c._strength_clip = self._strength_clip - return c - -class PatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.Patch) - self.patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: PatchHook = super().clone(subtype) - c.patches = self.patches - return c - # TODO: add functionality - -class ObjectPatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.ObjectPatch) - self.object_patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: ObjectPatchHook = super().clone(subtype) - c.object_patches = self.object_patches - return c - # TODO: add functionality - -class AddModelsHook(Hook): - def __init__(self, key: str=None, models: list['ModelPatcher']=None): - super().__init__(hook_type=EnumHookType.AddModels) - self.key = key - self.models = models - self.append_when_same = True - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: AddModelsHook = super().clone(subtype) - c.key = self.key - c.models = self.models.copy() if self.models else self.models - c.append_when_same = self.append_when_same - return c - # TODO: add functionality - -class CallbackHook(Hook): - def __init__(self, key: str=None, callback: Callable=None): - super().__init__(hook_type=EnumHookType.Callbacks) - self.key = key - self.callback = callback - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: CallbackHook = super().clone(subtype) - c.key = self.key - c.callback = self.callback - return c - # TODO: add functionality - -class WrapperHook(Hook): - def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict - - def clone(self, subtype: Callable=None): - if 
subtype is None: - subtype = type(self) - c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict - return c - - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): - return False - add_model_options = {"transformer_options": self.wrappers_dict} - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.append(self) - return True - -class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list['PatcherInjection']=None): - super().__init__(hook_type=EnumHookType.SetInjections) - self.key = key - self.injections = injections - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: SetInjectionsHook = super().clone(subtype) - c.key = self.key - c.injections = self.injections.copy() if self.injections else self.injections - return c - - def add_hook_injections(self, model: 'ModelPatcher'): - # TODO: add functionality - pass - -class HookGroup: - def __init__(self): - self.hooks: list[Hook] = [] - - def add(self, hook: Hook): - if hook not in self.hooks: - self.hooks.append(hook) - - def contains(self, hook: Hook): - return hook in self.hooks - - def clone(self): - c = HookGroup() - for hook in self.hooks: - c.add(hook.clone()) - return c - - def clone_and_combine(self, other: 'HookGroup'): - c = self.clone() - if other is not None: - for hook in other.hooks: - c.add(hook.clone()) - return c - - def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): - if hook_kf is None: - hook_kf = HookKeyframeGroup() - else: - hook_kf = hook_kf.clone() - for hook in self.hooks: - hook.hook_keyframe = hook_kf - - def get_dict_repr(self): - d: dict[EnumHookType, dict[Hook, None]] = {} - for hook in self.hooks: - with_type = d.setdefault(hook.hook_type, {}) - with_type[hook] = None - return d - - def get_hooks_for_clip_schedule(self): - scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} - for hook in self.hooks: - # only care about WeightHooks, for now - if hook.hook_type == EnumHookType.Weight: - hook_schedule = [] - # if no hook keyframes, assign default value - if len(hook.hook_keyframe.keyframes) == 0: - hook_schedule.append(((0.0, 1.0), None)) - scheduled_hooks[hook] = hook_schedule - continue - # find ranges of values - prev_keyframe = hook.hook_keyframe.keyframes[0] - for keyframe in hook.hook_keyframe.keyframes: - if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): - hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) - prev_keyframe = keyframe - elif keyframe.start_percent == prev_keyframe.start_percent: - prev_keyframe = keyframe - # create final range, assuming last start_percent was not 1.0 - if not math.isclose(prev_keyframe.start_percent, 1.0): - hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) - scheduled_hooks[hook] = hook_schedule - # hooks should not have their schedules in a list of tuples - all_ranges: list[tuple[float, float]] = [] - for range_kfs in scheduled_hooks.values(): - for t_range, keyframe in range_kfs: - all_ranges.append(t_range) - # turn list of ranges into boundaries - boundaries_set = set(itertools.chain.from_iterable(all_ranges)) - boundaries_set.add(0.0) - boundaries = sorted(boundaries_set) - real_ranges = [(boundaries[i], 
boundaries[i + 1]) for i in range(len(boundaries) - 1)] - # with real ranges defined, give appropriate hooks w/ keyframes for each range - scheduled_keyframes: list[tuple[tuple[float,float], list[tuple[WeightHook, HookKeyframe]]]] = [] - for t_range in real_ranges: - hooks_schedule = [] - for hook, val in scheduled_hooks.items(): - keyframe = None - # check if is a keyframe that works for the current t_range - for stored_range, stored_kf in val: - # if stored start is less than current end, then fits - give it assigned keyframe - if stored_range[0] < t_range[1] and stored_range[1] > t_range[0]: - keyframe = stored_kf - break - hooks_schedule.append((hook, keyframe)) - scheduled_keyframes.append((t_range, hooks_schedule)) - return scheduled_keyframes - - def reset(self): - for hook in self.hooks: - hook.reset() - - @staticmethod - def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': - actual: list[HookGroup] = [] - for group in hooks_list: - if group is not None: - actual.append(group) - if len(actual) < require_count: - raise Exception(f"Need at least {require_count} hooks to combine, but only had {len(actual)}.") - # if no hooks, then return None - if len(actual) == 0: - return None - # if only 1 hook, just return itself without cloning - elif len(actual) == 1: - return actual[0] - final_hook: HookGroup = None - for hook in actual: - if final_hook is None: - final_hook = hook.clone() - else: - final_hook = final_hook.clone_and_combine(hook) - return final_hook - - -class HookKeyframe: - def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1): - self.strength = strength - # scheduling - self.start_percent = float(start_percent) - self.start_t = 999999999.9 - self.guarantee_steps = guarantee_steps - - def clone(self): - c = HookKeyframe(strength=self.strength, - start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) - c.start_t = self.start_t - return c - -class HookKeyframeGroup: - def __init__(self): - self.keyframes: list[HookKeyframe] = [] - self._current_keyframe: HookKeyframe = None - self._current_used_steps = 0 - self._current_index = 0 - self._current_strength = None - self._curr_t = -1. - - # properties shadow those of HookWeightsKeyframe - @property - def strength(self): - if self._current_keyframe is not None: - return self._current_keyframe.strength - return 1.0 - - def reset(self): - self._current_keyframe = None - self._current_used_steps = 0 - self._current_index = 0 - self._current_strength = None - self.curr_t = -1. 
- self._set_first_as_current() - - def add(self, keyframe: HookKeyframe): - # add to end of list, then sort - self.keyframes.append(keyframe) - self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent") - self._set_first_as_current() - - def _set_first_as_current(self): - if len(self.keyframes) > 0: - self._current_keyframe = self.keyframes[0] - else: - self._current_keyframe = None - - def has_index(self, index: int): - return index >= 0 and index < len(self.keyframes) - - def is_empty(self): - return len(self.keyframes) == 0 - - def clone(self): - c = HookKeyframeGroup() - for keyframe in self.keyframes: - c.keyframes.append(keyframe.clone()) - c._set_first_as_current() - return c - - def initialize_timesteps(self, model: 'BaseModel'): - for keyframe in self.keyframes: - keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - - def prepare_current_keyframe(self, curr_t: float) -> bool: - if self.is_empty(): - return False - if curr_t == self._curr_t: - return False - prev_index = self._current_index - prev_strength = self._current_strength - # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: - # if has next index, loop through and see if need to switch - if self.has_index(self._current_index+1): - for i in range(self._current_index+1, len(self.keyframes)): - eval_c = self.keyframes[i] - # check if start_t is greater or equal to curr_t - # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling - if eval_c.start_t >= curr_t: - self._current_index = i - self._current_strength = eval_c.strength - self._current_keyframe = eval_c - self._current_used_steps = 0 - # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: - break - # if eval_c is outside the percent range, stop looking further - else: break - # update steps current context is used - self._current_used_steps += 1 - # update current timestep this was performed on - self._curr_t = curr_t - # return True if keyframe changed, False if no change - return prev_index != self._current_index and prev_strength != self._current_strength - - -class InterpolationMethod: - LINEAR = "linear" - EASE_IN = "ease_in" - EASE_OUT = "ease_out" - EASE_IN_OUT = "ease_in_out" - - _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT] - - @classmethod - def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False): - diff = num_to - num_from - if method == cls.LINEAR: - weights = torch.linspace(num_from, num_to, length) - elif method == cls.EASE_IN: - index = torch.linspace(0, 1, length) - weights = diff * np.power(index, 2) + num_from - elif method == cls.EASE_OUT: - index = torch.linspace(0, 1, length) - weights = diff * (1 - np.power(1 - index, 2)) + num_from - elif method == cls.EASE_IN_OUT: - index = torch.linspace(0, 1, length) - weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from - else: - raise ValueError(f"Unrecognized interpolation method '{method}'.") - if reverse: - weights = weights.flip(dims=(0,)) - return weights - -def get_sorted_list_via_attr(objects: list, attr: str) -> list: - if not objects: - return objects - elif len(objects) <= 1: - return [x for x in objects] - # now that we know we have to sort, do it following these rules: - # a) if objects have same value of attribute, maintain their relative order - # b) perform sorting of the groups of objects with same 
attributes - unique_attrs = {} - for o in objects: - val_attr = getattr(o, attr) - attr_list: list = unique_attrs.get(val_attr, list()) - attr_list.append(o) - if val_attr not in unique_attrs: - unique_attrs[val_attr] = attr_list - # now that we have the unique attr values grouped together in relative order, sort them by key - sorted_attrs = dict(sorted(unique_attrs.items())) - # now flatten out the dict into a list to return - sorted_list = [] - for object_list in sorted_attrs.values(): - sorted_list.extend(object_list) - return sorted_list - -def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): - hook_group = HookGroup() - hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) - hook_group.add(hook) - hook.weights = lora - return hook_group - -def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float, strength_clip: float): - hook_group = HookGroup() - hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) - hook_group.add(hook) - patches_model = None - patches_clip = None - if weights_model is not None: - patches_model = {} - for key in weights_model: - patches_model[key] = ("model_as_lora", (weights_model[key],)) - if weights_clip is not None: - patches_clip = {} - for key in weights_clip: - patches_clip[key] = ("model_as_lora", (weights_clip[key],)) - hook.weights = patches_model - hook.weights_clip = patches_clip - hook.need_weight_init = False - return hook_group - -def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): - if model is None: - return None - patches_model: dict[str, torch.Tensor] = model.model.state_dict() - if discard_model_sampling: - # do not include ANY model_sampling components of the model that should act as a patch - for key in list(patches_model.keys()): - if key.startswith("model_sampling"): - patches_model.pop(key, None) - return patches_model - -# NOTE: this function shows how to register weight hooks directly on the ModelPatchers -def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], - strength_model: float, strength_clip: float): - key_map = {} - if model is not None: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - if clip is not None: - key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) - - hook_group = HookGroup() - hook = WeightHook() - hook_group.add(hook) - loaded: dict[str] = comfy.lora.load_lora(lora, key_map) - if model is not None: - new_modelpatcher = model.clone() - k = new_modelpatcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_model) - else: - k = () - new_modelpatcher = None - - if clip is not None: - new_clip = clip.clone() - k1 = new_clip.patcher.add_hook_patches(hook=hook, patches=loaded, strength_patch=strength_clip) - else: - k1 = () - new_clip = None - k = set(k) - k1 = set(k1) - for x in loaded: - if (x not in k) and (x not in k1): - print(f"NOT LOADED {x}") - return (new_modelpatcher, new_clip, hook_group) - -def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]): - hooks_key = 'hooks' - # if hooks only exist in one dict, do what's needed so that it ends up in c_dict - if hooks_key not in values: - return - if hooks_key not in c_dict: - hooks_value = values.get(hooks_key, None) - if hooks_value is not None: - c_dict[hooks_key] = hooks_value - return - # otherwise, need to combine with minimum 
duplication via cache - hooks_tuple = (c_dict[hooks_key], values[hooks_key]) - cached_hooks = cache.get(hooks_tuple, None) - if cached_hooks is None: - new_hooks = hooks_tuple[0].clone_and_combine(hooks_tuple[1]) - cache[hooks_tuple] = new_hooks - c_dict[hooks_key] = new_hooks - else: - c_dict[hooks_key] = cache[hooks_tuple] - -def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): - c = [] - hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} - for t in conditioning: - n = [t[0], t[1].copy()] - for k in values: - if append_hooks and k == 'hooks': - _combine_hooks_from_values(n[1], values, hooks_combine_cache) - else: - n[1][k] = values[k] - c.append(n) - - return c - -def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): - if hooks is None: - return cond - return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) - -def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): - if timestep_range is None: - return cond - return conditioning_set_values(cond, {"start_percent": timestep_range[0], - "end_percent": timestep_range[1]}) - -def set_mask_for_conditioning(cond, mask: torch.Tensor, set_cond_area: str, strength: float): - if mask is None: - return cond - set_area_to_bounds = False - if set_cond_area != 'default': - set_area_to_bounds = True - if len(mask.shape) < 3: - mask = mask.unsqueeze(0) - return conditioning_set_values(cond, {'mask': mask, - 'set_area_to_bounds': set_area_to_bounds, - 'mask_strength': strength}) - -def combine_conditioning(conds: list): - combined_conds = [] - for cond in conds: - combined_conds.extend(cond) - return combined_conds - -def combine_with_new_conds(conds: list, new_conds: list): - combined_conds = [] - for c, new_c in zip(conds, new_conds): - combined_conds.append(combine_conditioning([c, new_c])) - return combined_conds - -def set_conds_props(conds: list, strength: float, set_cond_area: str, - mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): - final_conds = [] - for c in conds: - # first, apply lora_hook to conditioning, if provided - c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) - # next, apply mask to conditioning - c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) - # apply timesteps, if present - c = set_timesteps_for_conditioning(cond=c, timestep_range=timesteps_range) - # finally, apply mask to conditioning and store - final_conds.append(c) - return final_conds - -def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", - mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): - combined_conds = [] - for c, masked_c in zip(conds, new_conds): - # first, apply lora_hook to new conditioning, if provided - masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) - # next, apply mask to new conditioning, if provided - masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) - # apply timesteps, if present - masked_c = set_timesteps_for_conditioning(cond=masked_c, timestep_range=timesteps_range) - # finally, combine with existing conditioning and store - combined_conds.append(combine_conditioning([c, masked_c])) - return combined_conds - -def set_default_conds_and_combine(conds: list, new_conds: list, - hooks: 
HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True):
-    combined_conds = []
-    for c, new_c in zip(conds, new_conds):
-        # first, apply lora_hook to new conditioning, if provided
-        new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks)
-        # next, add default_cond key to cond so that during sampling, it can be identified
-        new_c = conditioning_set_values(new_c, {'default': True})
-        # apply timesteps, if present
-        new_c = set_timesteps_for_conditioning(cond=new_c, timestep_range=timesteps_range)
-        # finally, combine with existing conditioning and store
-        combined_conds.append(combine_conditioning([c, new_c]))
-    return combined_conds
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 2838b50c7..a82931627 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -317,6 +317,7 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
     s_in = x.new_ones([x.shape[0]])
     for i in trange(len(sigmas) - 1, disable=disable):
         denoised = model(x, sigmas[i] * s_in, **extra_args)
+        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
         downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
         sigma_down = sigmas[i+1] * downstep_ratio
         alpha_ip1 = 1 - sigmas[i+1]
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 3f7fee708..2902073d5 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -15,7 +15,6 @@ from .util import (
 )
 from ..attention import SpatialTransformer, SpatialVideoTransformer, default
 from comfy.ldm.util import exists
-import comfy.patcher_extension
 import comfy.ops
 ops = comfy.ops.disable_weight_init
 
@@ -48,15 +47,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
         elif isinstance(layer, Upsample):
             x = layer(x, output_shape=output_shape)
         else:
-            if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
-                found_patched = False
-                for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
-                    if isinstance(layer, class_type):
-                        x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
-                        found_patched = True
-                        break
-                if found_patched:
-                    continue
             x = layer(x)
     return x
 
@@ -829,13 +819,6 @@ class UNetModel(nn.Module):
         )
 
     def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
-        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
-            self._forward,
-            self,
-            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timesteps, context, y, control, transformer_options, **kwargs)
-
-    def _forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
         """
         Apply the model to an input batch.
         :param x: an [N x C x ...] Tensor of inputs.
diff --git a/comfy/lora.py b/comfy/lora.py
index b6d9a8d04..1080169b1 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -33,7 +33,7 @@ LORA_CLIP_MAP = {
 }
 
 
-def load_lora(lora, to_load, log_missing=True):
+def load_lora(lora, to_load):
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -213,10 +213,9 @@
             patch_dict[to_load[x]] = ("set", (set_weight,))
             loaded_keys.add(set_weight_name)
 
-    if log_missing:
-        for x in lora.keys():
-            if x not in loaded_keys:
-                logging.warning("lora key not loaded: {}".format(x))
+    for x in lora.keys():
+        if x not in loaded_keys:
+            logging.warning("lora key not loaded: {}".format(x))
 
     return patch_dict
 
@@ -430,7 +429,7 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
     return padded_tensor
 
 
-def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
+def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
         v = p[1]
@@ -472,11 +471,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
             weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
         elif patch_type == "set":
             weight.copy_(v[0])
-        elif patch_type == "model_as_lora":
-            target_weight: torch.Tensor = v[0]
-            diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
-                comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
-            weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
         elif patch_type == "lora": #lora/locon
             mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8f37af660..c305014a4 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -33,16 +33,12 @@ import comfy.ldm.flux.model
 import comfy.ldm.lightricks.model
 
 import comfy.model_management
-import comfy.patcher_extension
 import comfy.conds
 import comfy.ops
 from enum import Enum
 from .
import utils import comfy.latent_formats import math -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher class ModelType(Enum): EPS = 1 @@ -99,7 +95,6 @@ class BaseModel(torch.nn.Module): self.model_config = model_config self.manual_cast_dtype = model_config.manual_cast_dtype self.device = device - self.current_patcher: 'ModelPatcher' = None if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: @@ -125,13 +120,6 @@ class BaseModel(torch.nn.Module): self.memory_usage_factor = model_config.memory_usage_factor def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._apply_model, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.APPLY_MODEL, transformer_options) - ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) - - def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: diff --git a/comfy/model_management.py b/comfy/model_management.py index 46178e8f4..2d00c91e3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -23,8 +23,6 @@ from comfy.cli_args import args import torch import sys import platform -import weakref -import gc class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -303,27 +301,11 @@ def module_size(module): class LoadedModel: def __init__(self, model): - self._set_model(model) + self.model = model self.device = model.load_device + self.weights_loaded = False self.real_model = None self.currently_used = True - self.model_finalizer = None - self._patcher_finalizer = None - - def _set_model(self, model): - self._model = weakref.ref(model) - if model.parent is not None: - self._parent_model = weakref.ref(model.parent) - self._patcher_finalizer = weakref.finalize(model, self._switch_parent) - - def _switch_parent(self): - model = self._parent_model() - if model is not None: - self._set_model(model) - - @property - def model(self): - return self._model() def model_memory(self): return self.model.model_size() @@ -338,23 +320,32 @@ class LoadedModel: return self.model_memory() def model_load(self, lowvram_model_memory=0, force_patch_weights=False): + patch_model_to = self.device + self.model.model_patches_to(self.device) self.model.model_patches_to(self.model.model_dtype()) - # if self.model.loaded_size() > 0: - use_more_vram = lowvram_model_memory - if use_more_vram == 0: - use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) - real_model = self.model.model + load_weights = not self.weights_loaded - if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: + if self.model.loaded_size() > 0: + use_more_vram = lowvram_model_memory + if use_more_vram == 0: + use_more_vram = 1e32 + self.model_use_more_vram(use_more_vram) + else: + try: + self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights) + except Exception as e: + self.model.unpatch_model(self.model.offload_device) + self.model_unload() + raise e + + if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and 
self.real_model is not None: with torch.no_grad(): - real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) + self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True) - self.real_model = weakref.ref(real_model) - self.model_finalizer = weakref.finalize(real_model, cleanup_models) - return real_model + self.weights_loaded = True + return self.real_model def should_reload_model(self, force_patch_weights=False): if force_patch_weights and self.model.lowvram_patch_counter() > 0: @@ -367,23 +358,18 @@ class LoadedModel: freed = self.model.partially_unload(self.model.offload_device, memory_to_free) if freed >= memory_to_free: return False - self.model.detach(unpatch_weights) - self.model_finalizer.detach() - self.model_finalizer = None + self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) + self.model.model_patches_to(self.model.offload_device) + self.weights_loaded = self.weights_loaded and not unpatch_weights self.real_model = None return True - def model_use_more_vram(self, extra_memory, force_patch_weights=False): - return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) + def model_use_more_vram(self, extra_memory): + return self.model.partially_load(self.device, extra_memory) def __eq__(self, other): return self.model is other.model - def __del__(self): - if self._patcher_finalizer is not None: - self._patcher_finalizer.detach() - - def use_more_memory(extra_memory, loaded_models, device): for m in loaded_models: if m.device == device: @@ -414,8 +400,38 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() +def unload_model_clones(model, unload_weights_only=True, force_unload=True): + to_unload = [] + for i in range(len(current_loaded_models)): + if model.is_clone(current_loaded_models[i].model): + to_unload = [i] + to_unload + + if len(to_unload) == 0: + return True + + same_weights = 0 + for i in to_unload: + if model.clone_has_same_weights(current_loaded_models[i].model): + same_weights += 1 + + if same_weights == len(to_unload): + unload_weight = False + else: + unload_weight = True + + if not force_unload: + if unload_weights_only and unload_weight == False: + return None + else: + unload_weight = True + + for i in to_unload: + logging.debug("unload clone {} {}".format(i, unload_weight)) + current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight) + + return unload_weight + def free_memory(memory_required, device, keep_loaded=[]): - cleanup_models_gc() unloaded_model = [] can_unload = [] unloaded_models = [] @@ -452,7 +468,6 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): - cleanup_models_gc() global vram_state inference_memory = minimum_inference_memory() @@ -465,9 +480,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models = set(models) models_to_load = [] - + models_already_loaded = [] for x in models: loaded_model = LoadedModel(x) + loaded = None + try: loaded_model_index = current_loaded_models.index(loaded_model) except: @@ -475,35 +492,51 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model_index is not None: loaded = current_loaded_models[loaded_model_index] - loaded.currently_used 
= True - models_to_load.append(loaded) - else: + if loaded.should_reload_model(force_patch_weights=force_patch_weights): #TODO: cleanup this model reload logic + current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True) + loaded = None + else: + loaded.currently_used = True + models_already_loaded.append(loaded) + + if loaded is None: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) - for loaded_model in models_to_load: - to_unload = [] - for i in range(len(current_loaded_models)): - if loaded_model.model.is_clone(current_loaded_models[i].model): - to_unload = [i] + to_unload - for i in to_unload: - current_loaded_models.pop(i).model.detach(unpatch_all=False) + if len(models_to_load) == 0: + devs = set(map(lambda a: a.device, models_already_loaded)) + for d in devs: + if d != torch.device("cpu"): + free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded) + free_mem = get_free_memory(d) + if free_mem < minimum_memory_required: + logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed. + models_to_load = free_memory(minimum_memory_required, d) + logging.info("{} models unloaded.".format(len(models_to_load))) + else: + use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d) + if len(models_to_load) == 0: + return + + logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: + unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - for device in total_memory_required: - if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + for loaded_model in models_already_loaded: + total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + + for loaded_model in models_to_load: + weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded + if weights_unloaded is not None: + loaded_model.weights_loaded = not weights_unloaded for device in total_memory_required: if device != torch.device("cpu"): - free_mem = get_free_memory(device) - if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) - logging.info("{} models unloaded.".format(len(models_l))) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded) for loaded_model in models_to_load: model = loaded_model.model @@ -525,8 +558,17 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) + + + devs = set(map(lambda a: a.device, models_already_loaded)) + for d in devs: + if d != torch.device("cpu"): + free_mem = get_free_memory(d) + if free_mem > minimum_memory_required: + use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d) return + def 
load_model_gpu(model): return load_models_gpu([model]) @@ -540,35 +582,21 @@ def loaded_models(only_currently_used=False): output.append(m.model) return output - -def cleanup_models_gc(): - do_gc = False - for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.real_model() is not None and cur.model is None: - logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) - do_gc = True - break - - if do_gc: - gc.collect() - soft_empty_cache() - - for i in range(len(current_loaded_models)): - cur = current_loaded_models[i] - if cur.real_model() is not None and cur.model is None: - logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) - - - -def cleanup_models(): +def cleanup_models(keep_clone_weights_loaded=False): to_delete = [] for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: - to_delete = [i] + to_delete + #TODO: very fragile function needs improvement + num_refs = sys.getrefcount(current_loaded_models[i].model) + if num_refs <= 2: + if not keep_clone_weights_loaded: + to_delete = [i] + to_delete + #TODO: find a less fragile way to do this. + elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model + to_delete = [i] + to_delete for i in to_delete: x = current_loaded_models.pop(i) + x.model_unload() del x def dtype_size(dtype): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4ae3ad25d..f53f10749 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -16,8 +16,6 @@ along with this program. If not, see . 
""" -from __future__ import annotations -from typing import Optional, Callable import torch import copy import inspect @@ -30,9 +28,6 @@ import comfy.utils import comfy.float import comfy.model_management import comfy.lora -import comfy.hooks -import comfy.patcher_extension -from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.comfy_types import UnetWrapperFunction def string_to_seed(data): @@ -81,17 +76,6 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ model_options["disable_cfg1_optimization"] = True return model_options -def create_model_options_clone(orig_model_options: dict): - return comfy.patcher_extension.copy_nested_dicts(orig_model_options) - -def create_hook_patches_clone(orig_hook_patches): - new_hook_patches = {} - for hook_ref in orig_hook_patches: - new_hook_patches[hook_ref] = {} - for k in orig_hook_patches[hook_ref]: - new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] - return new_hook_patches - def wipe_lowvram_weight(m): if hasattr(m, "prev_comfy_cast_weights"): m.comfy_cast_weights = m.prev_comfy_cast_weights @@ -135,49 +119,6 @@ def get_key_weight(model, key): return weight, set_func, convert_func -class AutoPatcherEjector: - def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): - self.model = model - self.was_injected = False - self.prev_skip_injection = False - self.skip_and_inject_on_exit_only = skip_and_inject_on_exit_only - - def __enter__(self): - self.was_injected = False - self.prev_skip_injection = self.model.skip_injection - if self.skip_and_inject_on_exit_only: - self.model.skip_injection = True - if self.model.is_injected: - self.model.eject_model() - self.was_injected = True - - def __exit__(self, *args): - if self.skip_and_inject_on_exit_only: - self.model.skip_injection = self.prev_skip_injection - self.model.inject_model() - if self.was_injected and not self.model.skip_injection: - self.model.inject_model() - self.model.skip_injection = self.prev_skip_injection - -class MemoryCounter: - def __init__(self, initial: int, minimum=0): - self.value = initial - self.minimum = minimum - # TODO: add a safe limit besides 0 - - def use(self, weight: torch.Tensor): - weight_size = weight.nelement() * weight.element_size() - if self.is_useable(weight_size): - self.decrement(weight_size) - return True - return False - - def is_useable(self, used: int): - return self.value - used > self.minimum - - def decrement(self, used: int): - self.value -= used - class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -198,25 +139,6 @@ class ModelPatcher: self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update self.patches_uuid = uuid.uuid4() - self.parent = None - - self.attachments: dict[str] = {} - self.additional_models: dict[str, list[ModelPatcher]] = {} - self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks() - self.wrappers: dict[str, dict[str, list[Callable]]] = WrappersMP.init_wrappers() - - self.is_injected = False - self.skip_injection = False - self.injections: dict[str, list[PatcherInjection]] = {} - - self.hook_patches: dict[comfy.hooks._HookRef] = {} - self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} - self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} - self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} - self.current_hooks: Optional[comfy.hooks.HookGroup] 
= None - self.forced_hooks: Optional[comfy.hooks.HookGroup] = None # NOTE: only used for CLIP at this time - self.is_clip = False - self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -227,9 +149,6 @@ class ModelPatcher: if not hasattr(self.model, 'model_lowvram'): self.model.model_lowvram = False - if not hasattr(self.model, 'current_weight_patches_uuid'): - self.model.current_weight_patches_uuid = None - def model_size(self): if self.size > 0: return self.size @@ -253,48 +172,6 @@ class ModelPatcher: n.model_options = copy.deepcopy(self.model_options) n.backup = self.backup n.object_patches_backup = self.object_patches_backup - n.parent = self - - # attachments - n.attachments = {} - for k in self.attachments: - if hasattr(self.attachments[k], "on_model_patcher_clone"): - n.attachments[k] = self.attachments[k].on_model_patcher_clone() - else: - n.attachments[k] = self.attachments[k] - # additional models - for k, c in self.additional_models.items(): - n.additional_models[k] = [x.clone() for x in c] - # callbacks - for k, c in self.callbacks.items(): - n.callbacks[k] = {} - for k1, c1 in c.items(): - n.callbacks[k][k1] = c1.copy() - # sample wrappers - for k, w in self.wrappers.items(): - n.wrappers[k] = {} - for k1, w1 in w.items(): - n.wrappers[k][k1] = w1.copy() - # injection - n.is_injected = self.is_injected - n.skip_injection = self.skip_injection - for k, i in self.injections.items(): - n.injections[k] = i.copy() - # hooks - n.hook_patches = create_hook_patches_clone(self.hook_patches) - n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) - for group in self.cached_hook_patches: - n.cached_hook_patches[group] = {} - for k in self.cached_hook_patches[group]: - n.cached_hook_patches[group][k] = self.cached_hook_patches[group][k] - n.hook_backup = self.hook_backup - n.current_hooks = self.current_hooks.clone() if self.current_hooks else self.current_hooks - n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks - n.is_clip = self.is_clip - n.hook_mode = self.hook_mode - - for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): - callback(self, n) return n def is_clone(self, other): @@ -302,29 +179,10 @@ class ModelPatcher: return True return False - def clone_has_same_weights(self, clone: 'ModelPatcher'): + def clone_has_same_weights(self, clone): if not self.is_clone(clone): return False - if self.current_hooks != clone.current_hooks: - return False - if self.forced_hooks != clone.forced_hooks: - return False - if self.hook_patches.keys() != clone.hook_patches.keys(): - return False - if self.attachments.keys() != clone.attachments.keys(): - return False - if self.additional_models.keys() != clone.additional_models.keys(): - return False - for key in self.callbacks: - if len(self.callbacks[key]) != len(clone.callbacks[key]): - return False - for key in self.wrappers: - if len(self.wrappers[key]) != len(clone.wrappers[key]): - return False - if self.injections.keys() != clone.injections.keys(): - return False - if len(self.patches) == 0 and len(clone.patches) == 0: return True @@ -393,12 +251,6 @@ class ModelPatcher: def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") - def set_model_emb_patch(self, patch): - self.set_model_patch(patch, "emb_patch") - - def set_model_forward_timestep_embed_patch(self, patch): - self.set_model_patch(patch, "forward_timestep_embed_patch") - def 
add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -437,28 +289,27 @@ class ModelPatcher: return self.model.get_dtype() def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): - with self.use_ejected(): - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] + p = set() + model_sd = self.model.state_dict() + for k in patches: + offset = None + function = None + if isinstance(k, str): + key = k + else: + offset = k[1] + key = k[0] + if len(k) > 2: + function = k[2] - if key in model_sd: - p.add(k) - current_patches = self.patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - self.patches[key] = current_patches + if key in model_sd: + p.add(k) + current_patches = self.patches.get(key, []) + current_patches.append((strength_patch, patches[k], strength_model, offset, function)) + self.patches[key] = current_patches - self.patches_uuid = uuid.uuid4() - return list(p) + self.patches_uuid = uuid.uuid4() + return list(p) def get_key_patches(self, filter_prefix=None): model_sd = self.model_state_dict() @@ -468,12 +319,9 @@ class ModelPatcher: if not k.startswith(filter_prefix): continue bk = self.backup.get(k, None) - hbk = self.hook_backup.get(k, None) weight, set_func, convert_func = get_key_weight(self.model, k) if bk is not None: weight = bk.weight - if hbk is not None: - weight = hbk[0] if convert_func is None: convert_func = lambda a, **kwargs: a @@ -484,14 +332,13 @@ class ModelPatcher: return p def model_state_dict(self, filter_prefix=None): - with self.use_ejected(): - sd = self.model.state_dict() - keys = list(sd.keys()) - if filter_prefix is not None: - for k in keys: - if not k.startswith(filter_prefix): - sd.pop(k) - return sd + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd def patch_weight_to_device(self, key, device_to=None, inplace_update=False): if key not in self.patches: @@ -536,117 +383,105 @@ class ModelPatcher: return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): - with self.use_ejected(): - self.unpatch_hooks() - mem_counter = 0 - patch_counter = 0 - lowvram_counter = 0 - loading = self._load_list() + mem_counter = 0 + patch_counter = 0 + lowvram_counter = 0 + loading = self._load_list() - load_completely = [] - loading.sort(reverse=True) - for x in loading: - n = x[1] - m = x[2] - params = x[3] - module_mem = x[0] + load_completely = [] + loading.sort(reverse=True) + for x in loading: + n = x[1] + m = x[2] + params = x[3] + module_mem = x[0] - lowvram_weight = False + lowvram_weight = False - if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: - lowvram_weight = True - lowvram_counter += 1 - if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed - continue - - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - - if lowvram_weight: - if weight_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(weight_key) - else: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - if force_patch_weights: - self.patch_weight_to_device(bias_key) - else: - m.bias_function = LowVramPatch(bias_key, 
self.patches) - patch_counter += 1 - - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - else: - if hasattr(m, "comfy_cast_weights"): - if m.comfy_cast_weights: - wipe_lowvram_weight(m) - - if full_load or mem_counter + module_mem < lowvram_model_memory: - mem_counter += module_mem - load_completely.append((module_mem, n, m, params)) - - load_completely.sort(reverse=True) - for x in load_completely: - n = x[1] - m = x[2] - params = x[3] - if hasattr(m, "comfy_patched_weights"): - if m.comfy_patched_weights == True: + if not full_load and hasattr(m, "comfy_cast_weights"): + if mem_counter + module_mem >= lowvram_model_memory: + lowvram_weight = True + lowvram_counter += 1 + if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue - for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) - logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) - m.comfy_patched_weights = True + if lowvram_weight: + if weight_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - for x in load_completely: - x[2].to(device_to) - - if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) - self.model.model_lowvram = True + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) - self.model.model_lowvram = False - if full_load: - self.model.to(device_to) - mem_counter = self.model_size() + if hasattr(m, "comfy_cast_weights"): + if m.comfy_cast_weights: + wipe_lowvram_weight(m) - self.model.lowvram_patch_counter += patch_counter - self.model.device = device_to - self.model.model_loaded_weight_memory = mem_counter - self.model.current_weight_patches_uuid = self.patches_uuid + if full_load or mem_counter + module_mem < lowvram_model_memory: + mem_counter += module_mem + load_completely.append((module_mem, n, m, params)) - for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): - callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + load_completely.sort(reverse=True) + for x in load_completely: + n = x[1] + m = x[2] + params = x[3] + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: + continue - self.apply_hooks(self.forced_hooks, force_apply=True) + for param in params: + self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + + logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) + m.comfy_patched_weights = True + + for x in load_completely: + x[2].to(device_to) + + if lowvram_counter > 0: + logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + self.model.model_lowvram = True + else: + logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + self.model.model_lowvram = False + if full_load: + self.model.to(device_to) + mem_counter = 
self.model_size() + + self.model.lowvram_patch_counter += patch_counter + self.model.device = device_to + self.model.model_loaded_weight_memory = mem_counter def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): - with self.use_ejected(): - for k in self.object_patches: - old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) - if k not in self.object_patches_backup: - self.object_patches_backup[k] = old + for k in self.object_patches: + old = comfy.utils.set_attr(self.model, k, self.object_patches[k]) + if k not in self.object_patches_backup: + self.object_patches_backup[k] = old - if lowvram_model_memory == 0: - full_load = True - else: - full_load = False + if lowvram_model_memory == 0: + full_load = True + else: + full_load = False - if load_weights: - self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) - self.inject_model() + if load_weights: + self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) return self.model def unpatch_model(self, device_to=None, unpatch_weights=True): - self.eject_model() if unpatch_weights: - self.unpatch_hooks() if self.model.model_lowvram: for m in self.model.modules(): wipe_lowvram_weight(m) @@ -663,7 +498,6 @@ class ModelPatcher: else: comfy.utils.set_attr_param(self.model, k, bk.weight) - self.model.current_weight_patches_uuid = None self.backup.clear() if device_to is not None: @@ -682,92 +516,69 @@ class ModelPatcher: self.object_patches_backup.clear() def partially_unload(self, device_to, memory_to_free=0): - with self.use_ejected(): - memory_freed = 0 - patch_counter = 0 - unload_list = self._load_list() - unload_list.sort() - for unload in unload_list: - if memory_to_free < memory_freed: - break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + memory_freed = 0 + patch_counter = 0 + unload_list = self._load_list() + unload_list.sort() + for unload in unload_list: + if memory_to_free < memory_freed: + break + module_mem = unload[0] + n = unload[1] + m = unload[2] + params = unload[3] - lowvram_possible = hasattr(m, "comfy_cast_weights") - if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: - move_weight = True - for param in params: - key = "{}.{}".format(n, param) - bk = self.backup.get(key, None) - if bk is not None: - if not lowvram_possible: - move_weight = False - break + lowvram_possible = hasattr(m, "comfy_cast_weights") + if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: + move_weight = True + for param in params: + key = "{}.{}".format(n, param) + bk = self.backup.get(key, None) + if bk is not None: + if not lowvram_possible: + move_weight = False + break - if bk.inplace_update: - comfy.utils.copy_to_param(self.model, key, bk.weight) - else: - comfy.utils.set_attr_param(self.model, key, bk.weight) - self.backup.pop(key) - - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if move_weight: - m.to(device_to) - if lowvram_possible: - if weight_key in self.patches: - m.weight_function = LowVramPatch(weight_key, self.patches) - patch_counter += 1 - if bias_key in self.patches: - m.bias_function = LowVramPatch(bias_key, self.patches) - patch_counter += 1 + if bk.inplace_update: + comfy.utils.copy_to_param(self.model, key, bk.weight) + else: + comfy.utils.set_attr_param(self.model, key, bk.weight) + self.backup.pop(key) - 
m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True - m.comfy_patched_weights = False - memory_freed += module_mem - logging.debug("freed {}".format(n)) + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if move_weight: + m.to(device_to) + if lowvram_possible: + if weight_key in self.patches: + m.weight_function = LowVramPatch(weight_key, self.patches) + patch_counter += 1 + if bias_key in self.patches: + m.bias_function = LowVramPatch(bias_key, self.patches) + patch_counter += 1 - self.model.model_lowvram = True - self.model.lowvram_patch_counter += patch_counter - self.model.model_loaded_weight_memory -= memory_freed - return memory_freed + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + m.comfy_patched_weights = False + memory_freed += module_mem + logging.debug("freed {}".format(n)) - def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): - with self.use_ejected(skip_and_inject_on_exit_only=True): - unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) - # TODO: force_patch_weights should not unload + reload full model - used = self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) - if unpatch_weights: - extra_memory += (used - self.model.model_loaded_weight_memory) + self.model.model_lowvram = True + self.model.lowvram_patch_counter += patch_counter + self.model.model_loaded_weight_memory -= memory_freed + return memory_freed - self.patch_model(load_weights=False) - full_load = False - if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0: - self.apply_hooks(self.forced_hooks, force_apply=True) - return 0 - if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): - full_load = True - current_used = self.model.model_loaded_weight_memory - try: - self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load) - except Exception as e: - self.detach() - raise e - - return self.model.model_loaded_weight_memory - current_used - - def detach(self, unpatch_all=True): - self.eject_model() - self.model_patches_to(self.offload_device) - if unpatch_all: - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all) - for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH): - callback(self, unpatch_all) - return self.model + def partially_load(self, device_to, extra_memory=0): + self.unpatch_model(unpatch_weights=False) + self.patch_model(load_weights=False) + full_load = False + if self.model.model_lowvram == False: + return 0 + if self.model.model_loaded_weight_memory + extra_memory > self.model_size(): + full_load = True + current_used = self.model.model_loaded_weight_memory + self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load) + return self.model.model_loaded_weight_memory - current_used def current_loaded_device(self): return self.model.device @@ -775,346 +586,3 @@ class ModelPatcher: def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32): print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) - - def cleanup(self): - self.clean_hooks() - if hasattr(self.model, 
"current_patcher"): - self.model.current_patcher = None - for callback in self.get_all_callbacks(CallbacksMP.ON_CLEANUP): - callback(self) - - def add_callback(self, call_type: str, callback: Callable): - self.add_callback_with_key(call_type, None, callback) - - def add_callback_with_key(self, call_type: str, key: str, callback: Callable): - c = self.callbacks.setdefault(call_type, {}).setdefault(key, []) - c.append(callback) - - def remove_callbacks_with_key(self, call_type: str, key: str): - c = self.callbacks.get(call_type, {}) - if key in c: - c.pop(key) - - def get_callbacks(self, call_type: str, key: str): - return self.callbacks.get(call_type, {}).get(key, []) - - def get_all_callbacks(self, call_type: str): - c_list = [] - for c in self.callbacks.get(call_type, {}).values(): - c_list.extend(c) - return c_list - - def add_wrapper(self, wrapper_type: str, wrapper: Callable): - self.add_wrapper_with_key(wrapper_type, None, wrapper) - - def add_wrapper_with_key(self, wrapper_type: str, key: str, wrapper: Callable): - w = self.wrappers.setdefault(wrapper_type, {}).setdefault(key, []) - w.append(wrapper) - - def remove_wrappers_with_key(self, wrapper_type: str, key: str): - w = self.wrappers.get(wrapper_type, {}) - if key in w: - w.pop(key) - - def get_wrappers(self, wrapper_type: str, key: str): - return self.wrappers.get(wrapper_type, {}).get(key, []) - - def get_all_wrappers(self, wrapper_type: str): - w_list = [] - for w in self.wrappers.get(wrapper_type, {}).values(): - w_list.extend(w) - return w_list - - def set_attachments(self, key: str, attachment): - self.attachments[key] = attachment - - def remove_attachments(self, key: str): - if key in self.attachments: - self.attachments.pop(key) - - def get_attachment(self, key: str): - return self.attachments.get(key, None) - - def set_injections(self, key: str, injections: list[PatcherInjection]): - self.injections[key] = injections - - def remove_injections(self, key: str): - if key in self.injections: - self.injections.pop(key) - - def set_additional_models(self, key: str, models: list['ModelPatcher']): - self.additional_models[key] = models - - def remove_additional_models(self, key: str): - if key in self.additional_models: - self.additional_models.pop(key) - - def get_additional_models_with_key(self, key: str): - return self.additional_models.get(key, []) - - def get_additional_models(self): - all_models = [] - for models in self.additional_models.values(): - all_models.extend(models) - return all_models - - def get_nested_additional_models(self): - def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]): - '''Make sure circular references do not cause infinite recursion.''' - next_models = [] - for model in prev_models: - candidates = model.get_additional_models() - for c in candidates: - if c not in cache_set: - next_models.append(c) - cache_set.add(c) - if len(next_models) == 0: - return prev_models - return prev_models + _evaluate_sub_additional_models(next_models, cache_set) - - all_models = self.get_additional_models() - models_set = set(all_models) - real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set) - return real_all_models - - def use_ejected(self, skip_and_inject_on_exit_only=False): - return AutoPatcherEjector(self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only) - - def inject_model(self): - if self.is_injected or self.skip_injection: - return - for injections in self.injections.values(): - for inj in injections: - 
inj.inject(self) - self.is_injected = True - if self.is_injected: - for callback in self.get_all_callbacks(CallbacksMP.ON_INJECT_MODEL): - callback(self) - - def eject_model(self): - if not self.is_injected: - return - for injections in self.injections.values(): - for inj in injections: - inj.eject(self) - self.is_injected = False - for callback in self.get_all_callbacks(CallbacksMP.ON_EJECT_MODEL): - callback(self) - - def pre_run(self): - if hasattr(self.model, "current_patcher"): - self.model.current_patcher = self - for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): - callback(self) - - def prepare_state(self, timestep): - for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep) - - def restore_hook_patches(self): - if len(self.hook_patches_backup) > 0: - self.hook_patches = self.hook_patches_backup - self.hook_patches_backup = {} - - def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): - self.hook_mode = hook_mode - - def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): - curr_t = t[0] - reset_current_hooks = False - for hook in hook_group.hooks: - changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) - # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; - # this will cause the weights to be recalculated when sampling - if changed: - # reset current_hooks if contains hook that changed - if self.current_hooks is not None: - for current_hook in self.current_hooks.hooks: - if current_hook == hook: - reset_current_hooks = True - break - for cached_group in list(self.cached_hook_patches.keys()): - if cached_group.contains(hook): - self.cached_hook_patches.pop(cached_group) - if reset_current_hooks: - self.patch_hooks(None) - - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): - self.restore_hook_patches() - registered_hooks: list[comfy.hooks.Hook] = [] - # handle WrapperHooks, if model_options provided - if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): - hook.add_hook_patches(self, model_options, target, registered_hooks) - # handle WeightHooks - weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): - if hook.hook_ref not in self.hook_patches: - weight_hooks_to_register.append(hook) - if len(weight_hooks_to_register) > 0: - # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state - self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) - for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target, registered_hooks) - for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target) - - def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): - with self.use_ejected(): - # NOTE: this mirrors behavior of add_patches func - current_hook_patches: dict[str,list] = self.hook_patches.get(hook.hook_ref, {}) - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] - - if key in model_sd: - p.add(k) - current_patches: list[tuple] = 
current_hook_patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - current_hook_patches[key] = current_patches - self.hook_patches[hook.hook_ref] = current_hook_patches - # since should care about these patches too to determine if same model, reroll patches_uuid - self.patches_uuid = uuid.uuid4() - return list(p) - - def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup): - # combined_patches will contain weights of all relevant hooks, per key - combined_patches = {} - if hooks is not None: - for hook in hooks.hooks: - hook_patches: dict = self.hook_patches.get(hook.hook_ref, {}) - for key in hook_patches.keys(): - current_patches: list[tuple] = combined_patches.get(key, []) - if math.isclose(hook.strength, 1.0): - current_patches.extend(hook_patches[key]) - else: - # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model) - for patch in hook_patches[key]: - new_patch = list(patch) - new_patch[0] *= hook.strength - current_patches.append(tuple(new_patch)) - combined_patches[key] = current_patches - return combined_patches - - def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): - # TODO: return transformer_options dict with any additions from hooks - if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): - return {} - self.patch_hooks(hooks=hooks) - for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): - callback(self, hooks) - return {} - - def patch_hooks(self, hooks: comfy.hooks.HookGroup): - with self.use_ejected(): - self.unpatch_hooks() - if hooks is not None: - model_sd_keys = list(self.model_state_dict().keys()) - memory_counter = None - if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: - # TODO: minimum_counter should have a minimum that conforms to loaded model requirements - memory_counter = MemoryCounter(initial=comfy.model_management.get_free_memory(self.load_device), - minimum=comfy.model_management.minimum_inference_memory()*2) - # if have cached weights for hooks, use it - cached_weights = self.cached_hook_patches.get(hooks, None) - if cached_weights is not None: - for key in cached_weights: - if key not in model_sd_keys: - print(f"WARNING cached hook could not patch. key does not exist in model: {key}") - continue - self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) - else: - relevant_patches = self.get_combined_hook_patches(hooks=hooks) - original_weights = None - if len(relevant_patches) > 0: - original_weights = self.get_key_patches() - for key in relevant_patches: - if key not in model_sd_keys: - print(f"WARNING cached hook would not patch. 
key does not exist in model: {key}") - continue - self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, - memory_counter=memory_counter) - self.current_hooks = hooks - - def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): - if key not in self.hook_backup: - weight: torch.Tensor = comfy.utils.get_attr(self.model, key) - target_device = self.offload_device - if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: - used = memory_counter.use(weight) - if used: - target_device = weight.device - self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device) - comfy.utils.copy_to_param(self.model, key, cached_weights[key][0].to(device=cached_weights[key][1])) - - def clear_cached_hook_weights(self): - self.cached_hook_patches.clear() - self.patch_hooks(None) - - def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): - if key not in combined_patches: - return - - weight, set_func, convert_func = get_key_weight(self.model, key) - weight: torch.Tensor - if key not in self.hook_backup: - target_device = self.offload_device - if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: - used = memory_counter.use(weight) - if used: - target_device = weight.device - self.hook_backup[key] = (weight.to(device=target_device, copy=True), weight.device) - # TODO: properly handle LowVramPatch, if it ends up an issue - temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True) - if convert_func is not None: - temp_weight = convert_func(temp_weight, inplace=True) - - out_weight = comfy.lora.calculate_weight(combined_patches[key], - temp_weight, - key, original_weights=original_weights) - del original_weights[key] - if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) - if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: - # TODO: disable caching if not enough system RAM to do so - target_device = self.offload_device - used = memory_counter.use(weight) - if used: - target_device = weight.device - self.cached_hook_patches.setdefault(hooks, {}) - self.cached_hook_patches[hooks][key] = (out_weight.to(device=target_device, copy=False), weight.device) - del temp_weight - del out_weight - del weight - - def unpatch_hooks(self) -> None: - with self.use_ejected(): - if len(self.hook_backup) == 0: - self.current_hooks = None - return - keys = list(self.hook_backup.keys()) - for k in keys: - comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1])) - - self.hook_backup.clear() - self.current_hooks = None - - def clean_hooks(self): - self.unpatch_hooks() - self.clear_cached_hook_weights() - - def __del__(self): - self.detach(unpatch_all=False) - diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py deleted file mode 100644 index 514469185..000000000 --- a/comfy/patcher_extension.py +++ /dev/null @@ -1,156 +0,0 @@ -from __future__ import annotations -from typing import Callable - -class CallbacksMP: - ON_CLONE = "on_clone" - ON_LOAD = "on_load_after" - ON_DETACH = "on_detach_after" - ON_CLEANUP = "on_cleanup" - ON_PRE_RUN = "on_pre_run" - ON_PREPARE_STATE = "on_prepare_state" - 
ON_APPLY_HOOKS = "on_apply_hooks" - ON_REGISTER_ALL_HOOK_PATCHES = "on_register_all_hook_patches" - ON_INJECT_MODEL = "on_inject_model" - ON_EJECT_MODEL = "on_eject_model" - - # callbacks dict is in the format: - # {"call_type": {"key": [Callable1, Callable2, ...]} } - @classmethod - def init_callbacks(cls) -> dict[str, dict[str, list[Callable]]]: - return {} - -def add_callback(call_type: str, callback: Callable, transformer_options: dict, is_model_options=False): - add_callback_with_key(call_type, None, callback, transformer_options, is_model_options) - -def add_callback_with_key(call_type: str, key: str, callback: Callable, transformer_options: dict, is_model_options=False): - if is_model_options: - transformer_options = transformer_options.setdefault("transformer_options", {}) - callbacks: dict[str, dict[str, list]] = transformer_options.setdefault("callbacks", {}) - c = callbacks.setdefault(call_type, {}).setdefault(key, []) - c.append(callback) - -def get_callbacks_with_key(call_type: str, key: str, transformer_options: dict, is_model_options=False): - if is_model_options: - transformer_options = transformer_options.get("transformer_options", {}) - c_list = [] - callbacks: dict[str, list] = transformer_options.get("callbacks", {}) - c_list.extend(callbacks.get(call_type, {}).get(key, [])) - return c_list - -def get_all_callbacks(call_type: str, transformer_options: dict, is_model_options=False): - if is_model_options: - transformer_options = transformer_options.get("transformer_options", {}) - c_list = [] - callbacks: dict[str, list] = transformer_options.get("callbacks", {}) - for c in callbacks.get(call_type, {}).values(): - c_list.extend(c) - return c_list - -class WrappersMP: - OUTER_SAMPLE = "outer_sample" - SAMPLER_SAMPLE = "sampler_sample" - CALC_COND_BATCH = "calc_cond_batch" - APPLY_MODEL = "apply_model" - DIFFUSION_MODEL = "diffusion_model" - - # wrappers dict is in the format: - # {"wrapper_type": {"key": [Callable1, Callable2, ...]} } - @classmethod - def init_wrappers(cls) -> dict[str, dict[str, list[Callable]]]: - return {} - -def add_wrapper(wrapper_type: str, wrapper: Callable, transformer_options: dict, is_model_options=False): - add_wrapper_with_key(wrapper_type, None, wrapper, transformer_options, is_model_options) - -def add_wrapper_with_key(wrapper_type: str, key: str, wrapper: Callable, transformer_options: dict, is_model_options=False): - if is_model_options: - transformer_options = transformer_options.setdefault("transformer_options", {}) - wrappers: dict[str, dict[str, list]] = transformer_options.setdefault("wrappers", {}) - w = wrappers.setdefault(wrapper_type, {}).setdefault(key, []) - w.append(wrapper) - -def get_wrappers_with_key(wrapper_type: str, key: str, transformer_options: dict, is_model_options=False): - if is_model_options: - transformer_options = transformer_options.get("transformer_options", {}) - w_list = [] - wrappers: dict[str, list] = transformer_options.get("wrappers", {}) - w_list.extend(wrappers.get(wrapper_type, {}).get(key, [])) - return w_list - -def get_all_wrappers(wrapper_type: str, transformer_options: dict, is_model_options=False): - if is_model_options: - transformer_options = transformer_options.get("transformer_options", {}) - w_list = [] - wrappers: dict[str, list] = transformer_options.get("wrappers", {}) - for w in wrappers.get(wrapper_type, {}).values(): - w_list.extend(w) - return w_list - -class WrapperExecutor: - """Handles call stack of wrappers around a function in an ordered manner.""" - def __init__(self, original: 
Callable, class_obj: object, wrappers: list[Callable], idx: int): - # NOTE: class_obj exists so that wrappers surrounding a class method can access - # the class instance at runtime via executor.class_obj - self.original = original - self.class_obj = class_obj - self.wrappers = wrappers.copy() - self.idx = idx - self.is_last = idx == len(wrappers) - - def __call__(self, *args, **kwargs): - """Calls the next wrapper or original function, whichever is appropriate.""" - new_executor = self._create_next_executor() - return new_executor.execute(*args, **kwargs) - - def execute(self, *args, **kwargs): - """Used to initiate executor internally - DO NOT use this if you received executor in wrapper.""" - args = list(args) - kwargs = dict(kwargs) - if self.is_last: - return self.original(*args, **kwargs) - return self.wrappers[self.idx](self, *args, **kwargs) - - def _create_next_executor(self) -> 'WrapperExecutor': - new_idx = self.idx + 1 - if new_idx > len(self.wrappers): - raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.") - if self.class_obj is None: - return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx) - return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx) - - @classmethod - def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0): - return cls(original, class_obj=None, wrappers=wrappers, idx=idx) - - @classmethod - def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0): - return cls(original, class_obj, wrappers, idx=idx) - -class PatcherInjection: - def __init__(self, inject: Callable, eject: Callable): - self.inject = inject - self.eject = eject - -def copy_nested_dicts(input_dict: dict): - new_dict = input_dict.copy() - for key, value in input_dict.items(): - if isinstance(value, dict): - new_dict[key] = copy_nested_dicts(value) - elif isinstance(value, list): - new_dict[key] = value.copy() - return new_dict - -def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True): - if copy_dict1: - merged_dict = copy_nested_dicts(dict1) - else: - merged_dict = dict1 - for key, value in dict2.items(): - if isinstance(value, dict): - curr_value = merged_dict.setdefault(key, {}) - merged_dict[key] = merge_nested_dicts(value, curr_value) - elif isinstance(value, list): - merged_dict.setdefault(key, []).extend(value) - else: - merged_dict[key] = value - return merged_dict diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1252d8a5b..1879e670a 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,16 +1,7 @@ -from __future__ import annotations -import uuid import torch import comfy.model_management import comfy.conds import comfy.utils -import comfy.hooks -import comfy.patcher_extension -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher - from comfy.model_base import BaseModel - from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): return comfy.utils.reshape_mask(noise_mask, shape).to(device) @@ -19,43 +10,9 @@ def get_models_from_cond(cond, model_type): models = [] for c in cond: if model_type in c: - if isinstance(c[model_type], list): - models += c[model_type] - else: - models += [c[model_type]] + models += [c[model_type]] return models -def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): - # get hooks from conds, and collect cnets so they can be checked for 
extra_hooks - cnets: list[ControlBase] = [] - for c in cond: - if 'hooks' in c: - for hook in c['hooks'].hooks: - hook: comfy.hooks.Hook - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None - if 'control' in c: - cnets.append(c['control']) - - def get_extra_hooks_from_cnet(cnet: ControlBase, _list: list): - if cnet.extra_hooks is not None: - _list.append(cnet.extra_hooks) - if cnet.previous_controlnet is None: - return _list - return get_extra_hooks_from_cnet(cnet.previous_controlnet, _list) - - hooks_list = [] - cnets = set(cnets) - for base_cnet in cnets: - get_extra_hooks_from_cnet(base_cnet, hooks_list) - extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) - if extra_hooks is not None: - for hook in extra_hooks.hooks: - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None - - return hooks_dict - def convert_cond(cond): out = [] for c in cond: @@ -65,22 +22,17 @@ def convert_cond(cond): model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove temp["cross_attn"] = c[0] temp["model_conds"] = model_conds - temp["uuid"] = uuid.uuid4() out.append(temp) return out def get_additional_models(conds, dtype): """loads additional models in conditioning""" - cnets: list[ControlBase] = [] + cnets = [] gligen = [] - add_models = [] - hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") - add_models += get_models_from_cond(conds[k], "additional_models") - get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -91,9 +43,7 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] - models = control_models + gligen + add_models + hook_models - + models = control_models + gligen return models, inference_memory def cleanup_additional_models(models): @@ -103,11 +53,10 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): +def prepare_sampling(model, noise_shape, conds): device = model.load_device - real_model: 'BaseModel' = None + real_model = None models, inference_memory = get_additional_models(conds, model.model_dtype()) - models += model.get_nested_additional_models() # TODO: does this require inference_memory update? 
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) @@ -123,14 +72,3 @@ def cleanup_models(conds, models): control_cleanup += get_models_from_cond(conds[k], "control") cleanup_additional_models(set(control_cleanup)) - -def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): - # check for hooks in conds - if not registered, see if can be applied - hooks = {} - for k in conds: - get_hooks_from_cond(conds[k], hooks) - # add wrappers and callbacks from ModelPatcher to transformer_options - model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) - model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) - # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) diff --git a/comfy/samplers.py b/comfy/samplers.py index b4c42160d..94cba03b8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,21 +1,11 @@ -from __future__ import annotations from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher - from comfy.model_base import BaseModel - from comfy.controlnet import ControlBase import torch import collections from comfy import model_management import math import logging -import comfy.samplers import comfy.sampler_helpers -import comfy.model_patcher -import comfy.patcher_extension -import comfy.hooks import scipy.stats import numpy @@ -80,7 +70,6 @@ def get_area_and_mult(conds, x_in, timestep_in): for c in model_conds: conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - hooks = conds.get('hooks', None) control = conds.get('control', None) patches = None @@ -96,8 +85,8 @@ def get_area_and_mult(conds, x_in, timestep_in): patches['middle_patch'] = [gligen_patch] - cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches', 'uuid', 'hooks']) - return cond_obj(input_x, mult, conditioning, area, control, patches, conds['uuid'], hooks) + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) + return cond_obj(input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -149,184 +138,110 @@ def cond_cat(c_list): return out -def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): - # need to figure out remaining unmasked area for conds - default_mults = [] - for _ in default_conds: - default_mults.append(torch.ones_like(x_in)) - # look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond - for lora_hooks, to_run in hooked_to_run.items(): - for cond_obj, i in to_run: - # if no default_cond for cond_type, do nothing - if len(default_conds[i]) == 0: - continue - area: list[int] = cond_obj.area - if area is not None: - curr_default_mult: torch.Tensor = default_mults[i] - dims = len(area) // 2 - for i in range(dims): - curr_default_mult = 
curr_default_mult.narrow(i + 2, area[i + dims], area[i]) - curr_default_mult -= cond_obj.mult - else: - default_mults[i] -= cond_obj.mult - # for each default_mult, ReLU to make negatives=0, and then check for any nonzeros - for i, mult in enumerate(default_mults): - # if no default_cond for cond type, do nothing - if len(default_conds[i]) == 0: - continue - torch.nn.functional.relu(mult, inplace=True) - # if mult is all zeros, then don't add default_cond - if torch.max(mult) == 0.0: - continue - - cond = default_conds[i] - for x in cond: - # do get_area_and_mult to get all the expected values - p = comfy.samplers.get_area_and_mult(x, x_in, timestep) - if p is None: - continue - # replace p's mult with calculated mult - p = p._replace(mult=mult) - if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) - hooked_to_run.setdefault(p.hooks, list()) - hooked_to_run[p.hooks] += [(p, i)] - -def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): - executor = comfy.patcher_extension.WrapperExecutor.new_executor( - _calc_cond_batch, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) - ) - return executor.execute(model, conds, x_in, timestep, model_options) - -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def calc_cond_batch(model, conds, x_in, timestep, model_options): out_conds = [] out_counts = [] - # separate conds by matching hooks - hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} - default_conds = [] - has_default_conds = False + to_run = [] for i in range(len(conds)): out_conds.append(torch.zeros_like(x_in)) out_counts.append(torch.ones_like(x_in) * 1e-37) cond = conds[i] - default_c = [] if cond is not None: for x in cond: - if 'default' in x: - default_c.append(x) - has_default_conds = True - continue - p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue - if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) - hooked_to_run.setdefault(p.hooks, list()) - hooked_to_run[p.hooks] += [(p, i)] - default_conds.append(default_c) - if has_default_conds: - finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep) + to_run += [(p, i)] - model.current_patcher.prepare_state(timestep) + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] - # run every hooked_to_run separately - for hooks, to_run in hooked_to_run.items(): - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) * 1.5 < free_memory: + to_batch = batch_amount + break - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, 
len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: - to_batch = batch_amount - break + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - uuids = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x.append(p.input_x) - mult.append(p.mult) - c.append(p.conditioning) - area.append(p.area) - cond_or_uncond.append(o[1]) - uuids.append(p.uuid) - control = p.control - patches = p.patches + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - transformer_options = model.current_patcher.apply_hooks(hooks=hooks) - if 'transformer_options' in model_options: - transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, - model_options['transformer_options'], - copy_dict1=False) + transformer_options = {} + if 'transformer_options' in model_options: + transformer_options = model_options['transformer_options'].copy() - if patches is not None: - # TODO: replace with merge_nested_dicts function - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["uuids"] = uuids[:] - transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + transformer_options["patches"] = patches - for o in range(batch_chunks): - cond_index = cond_or_uncond[o] - a = area[o] - if a is None: - out_conds[cond_index] += output[o] * mult[o] - out_counts[cond_index] += mult[o] - else: - out_c = out_conds[cond_index] - out_cts = out_counts[cond_index] - dims = len(a) // 2 - for i in range(dims): - out_c = out_c.narrow(i + 2, a[i + dims], a[i]) - out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) - out_c += output[o] * mult[o] - out_cts += mult[o] + 
transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + + c['transformer_options'] = transformer_options + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] for i in range(len(out_conds)): out_conds[i] /= out_counts[i] @@ -585,15 +500,10 @@ def calculate_start_end_timesteps(model, conds): timestep_start = None timestep_end = None - # handle clip hook schedule, if needed - if 'clip_start_percent' in x: - timestep_start = s.percent_to_sigma(max(x['clip_start_percent'], x.get('start_percent', 0.0))) - timestep_end = s.percent_to_sigma(min(x['clip_end_percent'], x.get('end_percent', 1.0))) - else: - if 'start_percent' in x: - timestep_start = s.percent_to_sigma(x['start_percent']) - if 'end_percent' in x: - timestep_end = s.percent_to_sigma(x['end_percent']) + if 'start_percent' in x: + timestep_start = s.percent_to_sigma(x['start_percent']) + if 'end_percent' in x: + timestep_end = s.percent_to_sigma(x['end_percent']) if (timestep_start is not None) or (timestep_end is not None): n = x.copy() @@ -763,12 +673,6 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if k != kk: create_cond_with_same_area_if_none(conds[kk], c) - for k in conds: - for c in conds[k]: - if 'hooks' in c: - for hook in c['hooks'].hooks: - hook.initialize_timesteps(model) - for k in conds: pre_run_control(model, conds[k]) @@ -781,46 +685,9 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N return conds - -def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): - # determine which ControlNets have extra_hooks that should be combined with normal hooks - hook_replacement: dict[tuple[ControlBase, comfy.hooks.HookGroup], list[dict]] = {} - for k in conds: - for kk in conds[k]: - if 'control' in kk: - control: 'ControlBase' = kk['control'] - extra_hooks = control.get_extra_hooks() - if len(extra_hooks) > 0: - hooks: comfy.hooks.HookGroup = kk.get('hooks', None) - to_replace = hook_replacement.setdefault((control, hooks), []) - to_replace.append(kk) - # if nothing to replace, do nothing - if len(hook_replacement) == 0: - return - - # for optimal sampling performance, common ControlNets + hook combos should have identical hooks - # on the cond dicts - for key, conds_to_modify in hook_replacement.items(): - control = key[0] - hooks = key[1] - hooks = comfy.hooks.HookGroup.combine_all_hooks(control.get_extra_hooks() + [hooks]) - # if combined hooks are not None, set as new hooks for all relevant conds - if hooks is not None: - for cond in conds_to_modify: - cond['hooks'] = hooks - - -def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): - hooks_set = set() - for k in conds: - for kk in conds[k]: - hooks_set.add(kk.get('hooks', None)) - return len(hooks_set) - - class CFGGuider: def __init__(self, 
model_patcher): - self.model_patcher: 'ModelPatcher' = model_patcher + self.model_patcher = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -847,17 +714,19 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} + extra_args = {"model_options": self.model_options, "seed":seed} - executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( - sampler.sample, - sampler, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, extra_args["model_options"], is_model_options=True) - ) - samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) + samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + if sigmas.shape[-1] == 0: + return latent_image + + self.conds = {} + for k in self.original_conds: + self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) device = self.model_patcher.load_device @@ -868,46 +737,12 @@ class CFGGuider: latent_image = latent_image.to(device) sigmas = sigmas.to(device) - try: - self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) - finally: - self.model_patcher.cleanup() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model - del self.loaded_models - return output - - def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - if sigmas.shape[-1] == 0: - return latent_image - - self.conds = {} - for k in self.original_conds: - self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) - preprocess_conds_hooks(self.conds) - - try: - orig_model_options = self.model_options - self.model_options = comfy.model_patcher.create_model_options_clone(self.model_options) - # if one hook type (or just None), then don't bother caching weights for hooks (will never change after first step) - orig_hook_mode = self.model_patcher.hook_mode - if get_total_hook_groups_in_conds(self.conds) <= 1: - self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram - comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) - executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( - self.outer_sample, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) - ) - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) - finally: - self.model_options = orig_model_options - self.model_patcher.hook_mode = orig_hook_mode - 
self.model_patcher.restore_hook_patches() - del self.conds + del self.loaded_models return output diff --git a/comfy/sd.py b/comfy/sd.py index ebae7f996..e2af70781 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,10 +1,8 @@ -from __future__ import annotations import torch from enum import Enum import logging from comfy import model_management -from comfy.utils import ProgressBar from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.cascade.stage_a import StageA from .ldm.cascade.stage_c_coder import StageC_coder @@ -35,7 +33,6 @@ import comfy.text_encoders.lt import comfy.model_patcher import comfy.lora import comfy.lora_convert -import comfy.hooks import comfy.t2i_adapter.adapter import comfy.taesd.taesd @@ -101,13 +98,9 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) - self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram - self.patcher.is_clip = True - self.apply_hooks_to_conds = None if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None - self.use_clip_schedule = False logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) def clone(self): @@ -116,8 +109,6 @@ class CLIP: n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx - n.use_clip_schedule = self.use_clip_schedule - n.apply_hooks_to_conds = self.apply_hooks_to_conds return n def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): @@ -129,69 +120,6 @@ class CLIP: def tokenize(self, text, return_word_ids=False): return self.tokenizer.tokenize_with_weights(text, return_word_ids) - def add_hooks_to_dict(self, pooled_dict: dict[str]): - if self.apply_hooks_to_conds: - pooled_dict["hooks"] = self.apply_hooks_to_conds - return pooled_dict - - def encode_from_tokens_scheduled(self, tokens, unprojected=False, add_dict: dict[str]={}, show_pbar=True): - all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = [] - all_hooks = self.patcher.forced_hooks - if all_hooks is None or not self.use_clip_schedule: - # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict - return_pooled = "unprojected" if unprojected else True - pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True) - cond = pooled_dict.pop("cond") - # add/update any keys with the provided add_dict - pooled_dict.update(add_dict) - all_cond_pooled.append([cond, pooled_dict]) - else: - scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule() - - self.cond_stage_model.reset_clip_options() - if self.layer_idx is not None: - self.cond_stage_model.set_clip_options({"layer": self.layer_idx}) - if unprojected: - self.cond_stage_model.set_clip_options({"projected_pooled": False}) - - self.load_model() - all_hooks.reset() - self.patcher.patch_hooks(None) - if show_pbar: - pbar = ProgressBar(len(scheduled_keyframes)) - - for scheduled_opts in scheduled_keyframes: - t_range = scheduled_opts[0] - # don't bother encoding any conds outside of start_percent and end_percent bounds - if "start_percent" in add_dict: - if t_range[1] < add_dict["start_percent"]: - continue - if "end_percent" in add_dict: - if t_range[0] > add_dict["end_percent"]: - continue - hooks_keyframes = 
scheduled_opts[1] - for hook, keyframe in hooks_keyframes: - hook.hook_keyframe._current_keyframe = keyframe - # apply appropriate hooks with values that match new hook_keyframe - self.patcher.patch_hooks(all_hooks) - # perform encoding as normal - o = self.cond_stage_model.encode_token_weights(tokens) - cond, pooled = o[:2] - pooled_dict = {"pooled_output": pooled} - # add clip_start_percent and clip_end_percent in pooled - pooled_dict["clip_start_percent"] = t_range[0] - pooled_dict["clip_end_percent"] = t_range[1] - # add/update any keys with the provided add_dict - pooled_dict.update(add_dict) - # add hooks stored on clip - self.add_hooks_to_dict(pooled_dict) - all_cond_pooled.append([cond, pooled_dict]) - if show_pbar: - pbar.update(1) - model_management.throw_exception_if_processing_interrupted() - all_hooks.reset() - return all_cond_pooled - def encode_from_tokens(self, tokens, return_pooled=False, return_dict=False): self.cond_stage_model.reset_clip_options() @@ -209,7 +137,6 @@ class CLIP: if len(o) > 2: for k in o[2]: out[k] = o[2][k] - self.add_hooks_to_dict(out) return out if return_pooled: diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index b8e241578..3087b917b 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -17,7 +17,8 @@ class CLIPTextEncodeSDXLRefiner: def encode(self, clip, ascore, width, height, text): tokens = clip.tokenize(text) - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"aesthetic_score": ascore, "width": width, "height": height}), ) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], ) class CLIPTextEncodeSDXL: @classmethod @@ -46,7 +47,8 @@ class CLIPTextEncodeSDXL: tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}), ) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) NODE_CLASS_MAPPINGS = { "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 2ae23f735..b690432b5 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -18,7 +18,10 @@ class CLIPTextEncodeFlux: tokens = clip.tokenize(clip_l) tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] - return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) + output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) + cond = output.pop("cond") + output["guidance"] = guidance + return ([[cond, output]], ) class FluxGuidance: @classmethod diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py deleted file mode 100644 index 7b5344e5e..000000000 --- a/comfy_extras/nodes_hooks.py +++ /dev/null @@ -1,697 +0,0 @@ -from __future__ import annotations -from typing import TYPE_CHECKING, Union -import torch -from collections.abc import Iterable - -if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher - from comfy.sd import CLIP - -import comfy.hooks -import comfy.sd -import comfy.utils -import folder_paths - 
-########################################### -# Mask, Combine, and Hook Conditioning -#------------------------------------------ -class PairConditioningSetProperties: - NodeId = 'PairConditioningSetProperties' - NodeName = 'Cond Pair Set Props' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive_NEW": ("CONDITIONING", ), - "negative_NEW": ("CONDITIONING", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "set_cond_area": (["default", "mask bounds"],), - }, - "optional": { - "mask": ("MASK", ), - "hooks": ("HOOKS",), - "timesteps": ("TIMESTEPS_RANGE",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - CATEGORY = "advanced/hooks/cond pair" - FUNCTION = "set_properties" - - def set_properties(self, positive_NEW, negative_NEW, - strength: float, set_cond_area: str, - mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): - final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_NEW, negative_NEW], - strength=strength, set_cond_area=set_cond_area, - mask=mask, hooks=hooks, timesteps_range=timesteps) - return (final_positive, final_negative) - -class PairConditioningSetPropertiesAndCombine: - NodeId = 'PairConditioningSetPropertiesAndCombine' - NodeName = 'Cond Pair Set Props Combine' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "positive_NEW": ("CONDITIONING", ), - "negative_NEW": ("CONDITIONING", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "set_cond_area": (["default", "mask bounds"],), - }, - "optional": { - "mask": ("MASK", ), - "hooks": ("HOOKS",), - "timesteps": ("TIMESTEPS_RANGE",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - CATEGORY = "advanced/hooks/cond pair" - FUNCTION = "set_properties" - - def set_properties(self, positive, negative, positive_NEW, negative_NEW, - strength: float, set_cond_area: str, - mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): - final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_NEW, negative_NEW], - strength=strength, set_cond_area=set_cond_area, - mask=mask, hooks=hooks, timesteps_range=timesteps) - return (final_positive, final_negative) - -class ConditioningSetProperties: - NodeId = 'ConditioningSetProperties' - NodeName = 'Cond Set Props' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "cond_NEW": ("CONDITIONING", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "set_cond_area": (["default", "mask bounds"],), - }, - "optional": { - "mask": ("MASK", ), - "hooks": ("HOOKS",), - "timesteps": ("TIMESTEPS_RANGE",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING",) - CATEGORY = "advanced/hooks/cond single" - FUNCTION = "set_properties" - - def set_properties(self, cond_NEW, - strength: float, set_cond_area: str, - mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): - (final_cond,) = comfy.hooks.set_conds_props(conds=[cond_NEW], - strength=strength, set_cond_area=set_cond_area, - mask=mask, hooks=hooks, timesteps_range=timesteps) - return (final_cond,) - -class ConditioningSetPropertiesAndCombine: - NodeId = 'ConditioningSetPropertiesAndCombine' - NodeName = 
'Cond Set Props Combine' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "cond": ("CONDITIONING", ), - "cond_NEW": ("CONDITIONING", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "set_cond_area": (["default", "mask bounds"],), - }, - "optional": { - "mask": ("MASK", ), - "hooks": ("HOOKS",), - "timesteps": ("TIMESTEPS_RANGE",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING",) - CATEGORY = "advanced/hooks/cond single" - FUNCTION = "set_properties" - - def set_properties(self, cond, cond_NEW, - strength: float, set_cond_area: str, - mask: torch.Tensor=None, hooks: comfy.hooks.HookGroup=None, timesteps: tuple=None): - (final_cond,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_NEW], - strength=strength, set_cond_area=set_cond_area, - mask=mask, hooks=hooks, timesteps_range=timesteps) - return (final_cond,) - -class PairConditioningCombine: - NodeId = 'PairConditioningCombine' - NodeName = 'Cond Pair Combine' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive_A": ("CONDITIONING",), - "negative_A": ("CONDITIONING",), - "positive_B": ("CONDITIONING",), - "negative_B": ("CONDITIONING",), - }, - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - CATEGORY = "advanced/hooks/cond pair" - FUNCTION = "combine" - - def combine(self, positive_A, negative_A, positive_B, negative_B): - final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) - return (final_positive, final_negative,) - -class PairConditioningSetDefaultAndCombine: - NodeId = 'PairConditioningSetDefaultCombine' - NodeName = 'Cond Pair Set Default Combine' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "positive_DEFAULT": ("CONDITIONING",), - "negative_DEFAULT": ("CONDITIONING",), - }, - "optional": { - "hooks": ("HOOKS",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - CATEGORY = "advanced/hooks/cond pair" - FUNCTION = "set_default_and_combine" - - def set_default_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, - hooks: comfy.hooks.HookGroup=None): - final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], - hooks=hooks) - return (final_positive, final_negative) - -class ConditioningSetDefaultAndCombine: - NodeId = 'ConditioningSetDefaultCombine' - NodeName = 'Cond Set Default Combine' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "cond": ("CONDITIONING",), - "cond_DEFAULT": ("CONDITIONING",), - }, - "optional": { - "hooks": ("HOOKS",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING",) - CATEGORY = "advanced/hooks/cond single" - FUNCTION = "set_default_and_combine" - - def set_default_and_combine(self, cond, cond_DEFAULT, - hooks: comfy.hooks.HookGroup=None): - (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT], - hooks=hooks) - return (final_conditioning,) - -class SetClipHooks: - NodeId = 'SetClipHooks' - NodeName = 'Set CLIP Hooks' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "clip": ("CLIP",), - "apply_to_conds": ("BOOLEAN", {"default": True}), - "schedule_clip": 
("BOOLEAN", {"default": False}) - }, - "optional": { - "hooks": ("HOOKS",) - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CLIP",) - CATEGORY = "advanced/hooks/clip" - FUNCTION = "apply_hooks" - - def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): - if hooks is not None: - clip = clip.clone() - if apply_to_conds: - clip.apply_hooks_to_conds = hooks - clip.patcher.forced_hooks = hooks.clone() - clip.use_clip_schedule = schedule_clip - if not clip.use_clip_schedule: - clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) - return (clip,) - -class ConditioningTimestepsRange: - NodeId = 'ConditioningTimestepsRange' - NodeName = 'Timesteps Range' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }, - } - - EXPERIMENTAL = True - RETURN_TYPES = ("TIMESTEPS_RANGE", "TIMESTEPS_RANGE", "TIMESTEPS_RANGE") - RETURN_NAMES = ("TIMESTEPS_RANGE", "BEFORE_RANGE", "AFTER_RANGE") - CATEGORY = "advanced/hooks" - FUNCTION = "create_range" - - def create_range(self, start_percent: float, end_percent: float): - return ((start_percent, end_percent), (0.0, start_percent), (end_percent, 1.0)) -#------------------------------------------ -########################################### - - -########################################### -# Create Hooks -#------------------------------------------ -class CreateHookLora: - NodeId = 'CreateHookLora' - NodeName = 'Create Hook LoRA' - def __init__(self): - self.loaded_lora = None - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora_name": (folder_paths.get_filename_list("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - }, - "optional": { - "prev_hooks": ("HOOKS",) - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/create" - FUNCTION = "create_hook" - - def create_hook(self, lora_name: str, strength_model: float, strength_clip: float, prev_hooks: comfy.hooks.HookGroup=None): - if prev_hooks is None: - prev_hooks = comfy.hooks.HookGroup() - prev_hooks.clone() - - if strength_model == 0 and strength_clip == 0: - return (prev_hooks,) - - lora_path = folder_paths.get_full_path("loras", lora_name) - lora = None - if self.loaded_lora is not None: - if self.loaded_lora[0] == lora_path: - lora = self.loaded_lora[1] - else: - temp = self.loaded_lora - self.loaded_lora = None - del temp - - if lora is None: - lora = comfy.utils.load_torch_file(lora_path, safe_load=True) - self.loaded_lora = (lora_path, lora) - - hooks = comfy.hooks.create_hook_lora(lora=lora, strength_model=strength_model, strength_clip=strength_clip) - return (prev_hooks.clone_and_combine(hooks),) - -class CreateHookLoraModelOnly(CreateHookLora): - NodeId = 'CreateHookLoraModelOnly' - NodeName = 'Create Hook LoRA (MO)' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora_name": (folder_paths.get_filename_list("loras"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - }, - "optional": { - "prev_hooks": ("HOOKS",) - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/create" - 
FUNCTION = "create_hook_model_only" - - def create_hook_model_only(self, lora_name: str, strength_model: float, prev_hooks: comfy.hooks.HookGroup=None): - return self.create_hook(lora_name=lora_name, strength_model=strength_model, strength_clip=0, prev_hooks=prev_hooks) - -class CreateHookModelAsLora: - NodeId = 'CreateHookModelAsLora' - NodeName = 'Create Hook Model as LoRA' - - def __init__(self): - # when not None, will be in following format: - # (ckpt_path: str, weights_model: dict, weights_clip: dict) - self.loaded_weights = None - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - }, - "optional": { - "prev_hooks": ("HOOKS",) - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/create" - FUNCTION = "create_hook" - - def create_hook(self, ckpt_name: str, strength_model: float, strength_clip: float, - prev_hooks: comfy.hooks.HookGroup=None): - if prev_hooks is None: - prev_hooks = comfy.hooks.HookGroup() - prev_hooks.clone() - - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - weights_model = None - weights_clip = None - if self.loaded_weights is not None: - if self.loaded_weights[0] == ckpt_path: - weights_model = self.loaded_weights[1] - weights_clip = self.loaded_weights[2] - else: - temp = self.loaded_weights - self.loaded_weights = None - del temp - - if weights_model is None: - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - weights_model = comfy.hooks.get_patch_weights_from_model(out[0]) - weights_clip = comfy.hooks.get_patch_weights_from_model(out[1].patcher if out[1] else out[1]) - self.loaded_weights = (ckpt_path, weights_model, weights_clip) - - hooks = comfy.hooks.create_hook_model_as_lora(weights_model=weights_model, weights_clip=weights_clip, - strength_model=strength_model, strength_clip=strength_clip) - return (prev_hooks.clone_and_combine(hooks),) - -class CreateHookModelAsLoraModelOnly(CreateHookModelAsLora): - NodeId = 'CreateHookModelAsLoraModelOnly' - NodeName = 'Create Hook Model as LoRA (MO)' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), - "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - }, - "optional": { - "prev_hooks": ("HOOKS",) - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/create" - FUNCTION = "create_hook_model_only" - - def create_hook_model_only(self, ckpt_name: str, strength_model: float, - prev_hooks: comfy.hooks.HookGroup=None): - return self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=0.0, prev_hooks=prev_hooks) -#------------------------------------------ -########################################### - - -########################################### -# Schedule Hooks -#------------------------------------------ -class SetHookKeyframes: - NodeId = 'SetHookKeyframes' - NodeName = 'Set Hook Keyframes' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "hooks": ("HOOKS",), - }, - "optional": { - "hook_kf": ("HOOK_KEYFRAMES",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/scheduling" - 
FUNCTION = "set_hook_keyframes" - - def set_hook_keyframes(self, hooks: comfy.hooks.HookGroup, hook_kf: comfy.hooks.HookKeyframeGroup=None): - if hook_kf is not None: - hooks = hooks.clone() - hooks.set_keyframes_on_hooks(hook_kf=hook_kf) - return (hooks,) - -class CreateHookKeyframe: - NodeId = 'CreateHookKeyframe' - NodeName = 'Create Hook Keyframe' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "strength_mult": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }, - "optional": { - "prev_hook_kf": ("HOOK_KEYFRAMES",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOK_KEYFRAMES",) - RETURN_NAMES = ("HOOK_KF",) - CATEGORY = "advanced/hooks/scheduling" - FUNCTION = "create_hook_keyframe" - - def create_hook_keyframe(self, strength_mult: float, start_percent: float, prev_hook_kf: comfy.hooks.HookKeyframeGroup=None): - if prev_hook_kf is None: - prev_hook_kf = comfy.hooks.HookKeyframeGroup() - prev_hook_kf = prev_hook_kf.clone() - keyframe = comfy.hooks.HookKeyframe(strength=strength_mult, start_percent=start_percent) - prev_hook_kf.add(keyframe) - return (prev_hook_kf,) - -class CreateHookKeyframesFromFloats: - NodeId = 'CreateHookKeyframesFromFloats' - NodeName = 'Create Hook Keyframes From Floats' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "print_keyframes": ("BOOLEAN", {"default": False}), - }, - "optional": { - "prev_hook_kf": ("HOOK_KEYFRAMES",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOK_KEYFRAMES",) - RETURN_NAMES = ("HOOK_KF",) - CATEGORY = "advanced/hooks/scheduling" - FUNCTION = "create_hook_keyframes" - - def create_hook_keyframes(self, floats_strength: Union[float, list[float]], - start_percent: float, end_percent: float, - prev_hook_kf: comfy.hooks.HookKeyframeGroup=None, print_keyframes=False): - if prev_hook_kf is None: - prev_hook_kf = comfy.hooks.HookKeyframeGroup() - prev_hook_kf = prev_hook_kf.clone() - if type(floats_strength) in (float, int): - floats_strength = [float(floats_strength)] - elif isinstance(floats_strength, Iterable): - pass - else: - raise Exception(f"floats_strength must be either an iterable input or a float, but was{type(floats_strength).__repr__}.") - percents = comfy.hooks.InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=len(floats_strength), - method=comfy.hooks.InterpolationMethod.LINEAR) - - is_first = True - for percent, strength in zip(percents, floats_strength): - guarantee_steps = 0 - if is_first: - guarantee_steps = 1 - is_first = False - prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) - if print_keyframes: - print(f"Hook Keyframe - start_percent:{percent} = {strength}") - return (prev_hook_kf,) -#------------------------------------------ -########################################### - - -class SetModelHooksOnCond: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "conditioning": ("CONDITIONING",), - "hooks": ("HOOKS",), - }, - } - - EXPERIMENTAL = True - RETURN_TYPES = ("CONDITIONING",) - CATEGORY = "advanced/hooks/manual" - FUNCTION = "attach_hook" - - def attach_hook(self, conditioning, hooks: 
comfy.hooks.HookGroup): - return (comfy.hooks.set_hooks_for_conditioning(conditioning, hooks),) - - -########################################### -# Combine Hooks -#------------------------------------------ -class CombineHooks: - NodeId = 'CombineHooks2' - NodeName = 'Combine Hooks [2]' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "hooks_A": ("HOOKS",), - "hooks_B": ("HOOKS",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/combine" - FUNCTION = "combine_hooks" - - def combine_hooks(self, - hooks_A: comfy.hooks.HookGroup=None, - hooks_B: comfy.hooks.HookGroup=None): - candidates = [hooks_A, hooks_B] - return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) - -class CombineHooksFour: - NodeId = 'CombineHooks4' - NodeName = 'Combine Hooks [4]' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "hooks_A": ("HOOKS",), - "hooks_B": ("HOOKS",), - "hooks_C": ("HOOKS",), - "hooks_D": ("HOOKS",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/combine" - FUNCTION = "combine_hooks" - - def combine_hooks(self, - hooks_A: comfy.hooks.HookGroup=None, - hooks_B: comfy.hooks.HookGroup=None, - hooks_C: comfy.hooks.HookGroup=None, - hooks_D: comfy.hooks.HookGroup=None): - candidates = [hooks_A, hooks_B, hooks_C, hooks_D] - return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) - -class CombineHooksEight: - NodeId = 'CombineHooks8' - NodeName = 'Combine Hooks [8]' - @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "hooks_A": ("HOOKS",), - "hooks_B": ("HOOKS",), - "hooks_C": ("HOOKS",), - "hooks_D": ("HOOKS",), - "hooks_E": ("HOOKS",), - "hooks_F": ("HOOKS",), - "hooks_G": ("HOOKS",), - "hooks_H": ("HOOKS",), - } - } - - EXPERIMENTAL = True - RETURN_TYPES = ("HOOKS",) - CATEGORY = "advanced/hooks/combine" - FUNCTION = "combine_hooks" - - def combine_hooks(self, - hooks_A: comfy.hooks.HookGroup=None, - hooks_B: comfy.hooks.HookGroup=None, - hooks_C: comfy.hooks.HookGroup=None, - hooks_D: comfy.hooks.HookGroup=None, - hooks_E: comfy.hooks.HookGroup=None, - hooks_F: comfy.hooks.HookGroup=None, - hooks_G: comfy.hooks.HookGroup=None, - hooks_H: comfy.hooks.HookGroup=None): - candidates = [hooks_A, hooks_B, hooks_C, hooks_D, hooks_E, hooks_F, hooks_G, hooks_H] - return (comfy.hooks.HookGroup.combine_all_hooks(candidates),) -#------------------------------------------ -########################################### - -node_list = [ - # Create - CreateHookLora, - CreateHookLoraModelOnly, - CreateHookModelAsLora, - CreateHookModelAsLoraModelOnly, - # Scheduling - SetHookKeyframes, - CreateHookKeyframe, - CreateHookKeyframesFromFloats, - # Combine - CombineHooks, - CombineHooksFour, - CombineHooksEight, - # Attach - ConditioningSetProperties, - ConditioningSetPropertiesAndCombine, - PairConditioningSetProperties, - PairConditioningSetPropertiesAndCombine, - ConditioningSetDefaultAndCombine, - PairConditioningSetDefaultAndCombine, - PairConditioningCombine, - SetClipHooks, - # Other - ConditioningTimestepsRange, -] -NODE_CLASS_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS = {} - -for node in node_list: - NODE_CLASS_MAPPINGS[node.NodeId] = node - NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 2bd295e24..b03eaf6a2 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -15,7 +15,9 @@ class 
CLIPTextEncodeHunyuanDiT:
         tokens = clip.tokenize(bert)
         tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
-        return (clip.encode_from_tokens_scheduled(tokens), )
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )
 NODE_CLASS_MAPPINGS = {
diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py
index d75b29e60..6ef3c293d 100644
--- a/comfy_extras/nodes_sd3.py
+++ b/comfy_extras/nodes_sd3.py
@@ -82,7 +82,8 @@ class CLIPTextEncodeSD3:
                 tokens["l"] += empty["l"]
             while len(tokens["l"]) > len(tokens["g"]):
                 tokens["g"] += empty["g"]
-        return (clip.encode_from_tokens_scheduled(tokens), )
+        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
+        return ([[cond, {"pooled_output": pooled}]], )
 class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
diff --git a/execution.py b/execution.py
index 768e35abc..6c386341b 100644
--- a/execution.py
+++ b/execution.py
@@ -480,7 +480,7 @@ class PromptExecutor:
                 if self.caches.outputs.get(node_id) is not None:
                     cached_nodes.append(node_id)
-            comfy.model_management.cleanup_models_gc()
+            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached", { "nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)
diff --git a/main.py b/main.py
index c2c2ff8c8..05eb31c7a 100644
--- a/main.py
+++ b/main.py
@@ -154,6 +154,7 @@ def prompt_worker(q, server):
         if need_gc:
             current_time = time.perf_counter()
             if (current_time - last_gc_collect) > gc_collect_interval:
+                comfy.model_management.cleanup_models()
                 gc.collect()
                 comfy.model_management.soft_empty_cache()
                 last_gc_collect = current_time
diff --git a/nodes.py b/nodes.py
index 260bb5e15..fb504da35 100644
--- a/nodes.py
+++ b/nodes.py
@@ -62,8 +62,9 @@ class CLIPTextEncode:
     def encode(self, clip, text):
         tokens = clip.tokenize(text)
-        return (clip.encode_from_tokens_scheduled(tokens), )
-
+        output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
+        cond = output.pop("cond")
+        return ([[cond, output]], )
 class ConditioningCombine:
     @classmethod
@@ -2148,7 +2149,6 @@ def init_builtin_extra_nodes():
         "nodes_mochi.py",
         "nodes_slg.py",
         "nodes_lt.py",
-        "nodes_hooks.py",
     ]
     import_failed = []
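
For reference: the recurring edit across the text-encode hunks above (CLIPTextEncode, CLIPTextEncodeSDXL, CLIPTextEncodeFlux, CLIPTextEncodeHunyuanDiT, CLIPTextEncodeSD3) swaps clip.encode_from_tokens_scheduled(...) back to the older clip.encode_from_tokens(...) call and rebuilds the CONDITIONING list by hand. Below is a minimal sketch of that restored pattern, assuming a loaded comfy.sd.CLIP instance named clip; the helper name encode_simple and the extra argument are illustrative only and not part of the patch.

    # Illustrative sketch (not part of the patch): conditioning construction
    # without the scheduled/hook-aware encode path that this commit removes.
    def encode_simple(clip, text, extra=None):
        tokens = clip.tokenize(text)
        # Single encode call; unpack the cond tensor and the pooled output.
        cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
        options = {"pooled_output": pooled}
        if extra:
            # Per-model keys (e.g. a "guidance" value for Flux, or the SDXL
            # width/height/crop fields) ride along in the options dict.
            options.update(extra)
        # CONDITIONING is a list of [cond_tensor, options_dict] pairs.
        return [[cond, options]]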