From 72bbf493494109d3c177ad5de378c6c4bbae61d1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 29 Dec 2024 15:49:09 -0600 Subject: [PATCH 01/52] Add 'sigmas' to transformer_options so that downstream code can know about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas' --- comfy/hooks.py | 21 +++++++++++++++++---- comfy/model_patcher.py | 5 +++-- comfy/samplers.py | 10 ++++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cf33598ae..79a7090ba 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -366,9 +366,15 @@ class HookKeyframe: self.start_t = 999999999.9 self.guarantee_steps = guarantee_steps + def get_effective_guarantee_steps(self, max_sigma: torch.Tensor): + '''If keyframe starts before current sampling range (max_sigma), treat as 0.''' + if self.start_t > max_sigma: + return 0 + return self.guarantee_steps + def clone(self): c = HookKeyframe(strength=self.strength, - start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) + start_percent=self.start_percent, guarantee_steps=self.guarantee_steps) c.start_t = self.start_t return c @@ -408,6 +414,12 @@ class HookKeyframeGroup: else: self._current_keyframe = None + def has_guarantee_steps(self): + for kf in self.keyframes: + if kf.guarantee_steps > 0: + return True + return False + def has_index(self, index: int): return index >= 0 and index < len(self.keyframes) @@ -425,15 +437,16 @@ class HookKeyframeGroup: for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, curr_t: float) -> bool: + def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool: if self.is_empty(): return False if curr_t == self._curr_t: return False + max_sigma = torch.max(transformer_options["sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch - if self._current_used_steps >= self._current_keyframe.guarantee_steps: + if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma): # if has next index, loop through and see if need to switch if self.has_index(self._current_index+1): for i in range(self._current_index+1, len(self.keyframes)): @@ -446,7 +459,7 @@ class HookKeyframeGroup: self._current_keyframe = eval_c self._current_used_steps = 0 # if guarantee_steps greater than zero, stop searching for other keyframes - if self._current_keyframe.guarantee_steps > 0: + if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further else: break diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d89d9a6a3..4597ce11c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -919,11 +919,12 @@ class ModelPatcher: def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode - def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup): + def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: - changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t) + changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: diff --git a/comfy/samplers.py b/comfy/samplers.py index 27686722d..6a386511a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -144,7 +144,7 @@ def cond_cat(c_list): return out -def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep): +def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options): # need to figure out remaining unmasked area for conds default_mults = [] for _ in default_conds: @@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H # replace p's mult with calculated mult p = p._replace(mult=mult) if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] @@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te if p is None: continue if p.hooks is not None: - model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks) + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] default_conds.append(default_c) @@ -840,7 +840,9 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) - extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed} + extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) + extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( sampler.sample, From 5a2ad032cb09afcaf7fadf5cdfa20c2b0498aee5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 3 Jan 2025 20:02:27 -0600 Subject: [PATCH 02/52] Cleaned up hooks.py, refactored Hook.should_register and add_hook_patches to use target_dict instead of target so that more information can be provided about the current execution environment if needed --- comfy/hooks.py | 148 +++++++++++++++++++++++------------- comfy/model_patcher.py | 8 +- comfy/sampler_helpers.py | 2 +- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 102 insertions(+), 58 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090ba..181c4996a 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -16,46 +16,86 @@ import comfy.model_management import comfy.patcher_extension from node_helpers import conditioning_set_values +# ####################################################################################################### +# Hooks explanation +# ------------------- +# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to +# make explicit special cases like it does for ControlNet and GLIGEN. +# +# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those +# that should run special code when a 'marked' cond is used in sampling. +# ####################################################################################################### + class EnumHookMode(enum.Enum): + ''' + Priority of hook memory optimization vs. speed, mostly related to WeightHooks. + + MinVram: No caching will occur for any operations related to hooks. + MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups. + ''' MinVram = "minvram" MaxSpeed = "maxspeed" class EnumHookType(enum.Enum): + ''' + Hook types, each of which has different expected behavior. + ''' Weight = "weight" Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Callbacks = "callbacks" Wrappers = "wrappers" - SetInjections = "add_injections" + Injections = "add_injections" class EnumWeightTarget(enum.Enum): Model = "model" Clip = "clip" +class EnumHookScope(enum.Enum): + ''' + Determines if hook should be limited in its influence over sampling. + + AllConditioning: hook will affect all conds used in sampling. + HookedOnly: hook will only affect the conds it was attached to. + ''' + AllConditioning = "all_conditioning" + HookedOnly = "hooked_only" + + class _HookRef: pass -# NOTE: this is an example of how the should_register function should look -def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + '''Example for how should_register function should look like.''' return True +def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]: + '''Creates base dictionary for use with Hooks' target param.''' + d = {} + if target is not None: + d['target'] = target + d.update(kwargs) + return d + + class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, - hook_keyframe: 'HookKeyframeGroup'=None): + hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False + self.hook_scope = hook_scope @property def strength(self): return self.hook_keyframe.strength - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): self.reset() self.hook_keyframe.initialize_timesteps(model) @@ -75,27 +115,32 @@ class Hook: c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - return self.custom_should_register(self, model, model_options, target, registered) + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]): + def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): pass - def __eq__(self, other: 'Hook'): + def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref def __hash__(self): return hash(self.hook_ref) class WeightHook(Hook): + ''' + Hook responsible for tracking weights to be applied to some model/clip. + + Note, value of hook_scope is ignored and is treated as HookedOnly. + ''' def __init__(self, strength_model=1.0, strength_clip=1.0): - super().__init__(hook_type=EnumHookType.Weight) + super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly) self.weights: dict = None self.weights_clip: dict = None self.need_weight_init = True @@ -110,27 +155,29 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False weights = None - if target == EnumWeightTarget.Model: - strength = self._strength_model - else: + + target = target_dict.get('target', None) + if target == EnumWeightTarget.Clip: strength = self._strength_clip + else: + strength = self._strength_model if self.need_weight_init: key_map = {} - if target == EnumWeightTarget.Model: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - else: + if target == EnumWeightTarget.Clip: key_map = comfy.lora.model_lora_keys_clip(model.model, key_map) + else: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False) else: - if target == EnumWeightTarget.Model: - weights = self.weights - else: + if target == EnumWeightTarget.Clip: weights = self.weights_clip + else: + weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) registered.append(self) return True @@ -174,7 +221,12 @@ class ObjectPatchHook(Hook): # TODO: add functionality class AddModelsHook(Hook): - def __init__(self, key: str=None, models: list['ModelPatcher']=None): + ''' + Hook responsible for telling model management any additional models that should be loaded. + + Note, value of hook_scope is ignored and is treated as AllConditioning. + ''' + def __init__(self, key: str=None, models: list[ModelPatcher]=None): super().__init__(hook_type=EnumHookType.AddModels) self.key = key self.models = models @@ -188,24 +240,15 @@ class AddModelsHook(Hook): c.models = self.models.copy() if self.models else self.models c.append_when_same = self.append_when_same return c - # TODO: add functionality -class CallbackHook(Hook): - def __init__(self, key: str=None, callback: Callable=None): - super().__init__(hook_type=EnumHookType.Callbacks) - self.key = key - self.callback = callback - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: CallbackHook = super().clone(subtype) - c.key = self.key - c.callback = self.callback - return c - # TODO: add functionality + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): + return False class WrapperHook(Hook): + ''' + Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): super().__init__(hook_type=EnumHookType.Wrappers) self.wrappers_dict = wrappers_dict @@ -217,17 +260,18 @@ class WrapperHook(Hook): c.wrappers_dict = self.wrappers_dict return c - def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]): - if not self.should_register(model, model_options, target, registered): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + if not self.should_register(model, model_options, target_dict, registered): return False add_model_options = {"transformer_options": self.wrappers_dict} - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + if self.hook_scope == EnumHookScope.AllConditioning: + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list['PatcherInjection']=None): - super().__init__(hook_type=EnumHookType.SetInjections) + def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections @@ -239,7 +283,7 @@ class SetInjectionsHook(Hook): c.injections = self.injections.copy() if self.injections else self.injections return c - def add_hook_injections(self, model: 'ModelPatcher'): + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass @@ -260,14 +304,14 @@ class HookGroup: c.add(hook.clone()) return c - def clone_and_combine(self, other: 'HookGroup'): + def clone_and_combine(self, other: HookGroup): c = self.clone() if other is not None: for hook in other.hooks: c.add(hook.clone()) return c - def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'): + def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup): if hook_kf is None: hook_kf = HookKeyframeGroup() else: @@ -336,7 +380,7 @@ class HookGroup: hook.reset() @staticmethod - def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup': + def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup: actual: list[HookGroup] = [] for group in hooks_list: if group is not None: @@ -433,7 +477,7 @@ class HookKeyframeGroup: c._set_first_as_current() return c - def initialize_timesteps(self, model: 'BaseModel'): + def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) @@ -548,7 +592,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float hook.need_weight_init = False return hook_group -def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True): +def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True): if model is None: return None patches_model: dict[str, torch.Tensor] = model.model.state_dict() @@ -560,7 +604,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T return patches_model # NOTE: this function shows how to register weight hooks directly on the ModelPatchers -def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor], +def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): key_map = {} if model is not None: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4597ce11c..071535526 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,13 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None): + def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): @@ -956,9 +956,9 @@ class ModelPatcher: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target) + callback(self, hooks_dict, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index ac9735369..6f21ca3cf 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -131,4 +131,4 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options) + model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 9d9d48378..49b90b9d5 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip) + clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 776aa734e1ac0a46fefef6abcc5ad29763003a7e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 01:02:21 -0600 Subject: [PATCH 03/52] Refactor WrapperHook into TransformerOptionsHook, as there is no need to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options) --- comfy/hooks.py | 24 +++++++++++++++++------- comfy/model_patcher.py | 2 +- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 181c4996a..7ca3a8a11 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -44,7 +44,7 @@ class EnumHookType(enum.Enum): Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" - Wrappers = "wrappers" + TransformerOptions = "transformer_options" Injections = "add_injections" class EnumWeightTarget(enum.Enum): @@ -245,29 +245,39 @@ class AddModelsHook(Hook): if not self.should_register(model, model_options, target_dict, registered): return False -class WrapperHook(Hook): +class TransformerOptionsHook(Hook): ''' - Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options. + Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): - super().__init__(hook_type=EnumHookType.Wrappers) - self.wrappers_dict = wrappers_dict + super().__init__(hook_type=EnumHookType.TransformerOptions) + self.transformers_dict = wrappers_dict def clone(self, subtype: Callable=None): if subtype is None: subtype = type(self) c: WrapperHook = super().clone(subtype) - c.wrappers_dict = self.wrappers_dict + c.transformers_dict = self.transformers_dict return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.wrappers_dict} + add_model_options = {"transformer_options": self.transformers_dict} + # TODO: call .to on patches/anything else in transformer_options that is expected to do something if self.hook_scope == EnumHookScope.AllConditioning: comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) registered.append(self) return True + + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + +class WrapperHook(TransformerOptionsHook): + ''' + For backwards compatibility, this hook is identical to TransformerOptionsHook. + ''' + pass class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 071535526..2db21bdc4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -945,7 +945,7 @@ class ModelPatcher: registered_hooks: list[comfy.hooks.Hook] = [] # handle WrapperHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}): + for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] From 111fd0cadfe83cdda7a1a775f89e0dd675a58d66 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 4 Jan 2025 02:04:07 -0600 Subject: [PATCH 04/52] Refactored HookGroup to also store a dictionary of hooks separated by hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type --- comfy/hooks.py | 78 ++++++++++++++++++------------------- comfy/model_patcher.py | 10 ++--- comfy/sampler_helpers.py | 20 +++++----- comfy_extras/nodes_hooks.py | 2 +- 4 files changed, 53 insertions(+), 57 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7ca3a8a11..9ccfaa6d1 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -41,7 +41,6 @@ class EnumHookType(enum.Enum): Hook types, each of which has different expected behavior. ''' Weight = "weight" - Patch = "patch" ObjectPatch = "object_patch" AddModels = "add_models" TransformerOptions = "transformer_options" @@ -194,19 +193,6 @@ class WeightHook(Hook): c._strength_clip = self._strength_clip return c -class PatchHook(Hook): - def __init__(self): - super().__init__(hook_type=EnumHookType.Patch) - self.patches: dict = None - - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: PatchHook = super().clone(subtype) - c.patches = self.patches - return c - # TODO: add functionality - class ObjectPatchHook(Hook): def __init__(self): super().__init__(hook_type=EnumHookType.ObjectPatch) @@ -244,6 +230,7 @@ class AddModelsHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): if not self.should_register(model, model_options, target_dict, registered): return False + return True class TransformerOptionsHook(Hook): ''' @@ -298,12 +285,28 @@ class SetInjectionsHook(Hook): pass class HookGroup: + ''' + Stores groups of hooks, and allows them to be queried by type. + + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; + always use the provided functions on HookGroup. + ''' def __init__(self): self.hooks: list[Hook] = [] + self._hook_dict: dict[EnumHookType, list[Hook]] = {} def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) + self._hook_dict.setdefault(hook.hook_type, []).append(hook) + + def remove(self, hook: Hook): + if hook in self.hooks: + self.hooks.remove(hook) + self._hook_dict[hook.hook_type].remove(hook) + + def get_type(self, hook_type: EnumHookType): + return self._hook_dict.get(hook_type, []) def contains(self, hook: Hook): return hook in self.hooks @@ -329,36 +332,29 @@ class HookGroup: for hook in self.hooks: hook.hook_keyframe = hook_kf - def get_dict_repr(self): - d: dict[EnumHookType, dict[Hook, None]] = {} - for hook in self.hooks: - with_type = d.setdefault(hook.hook_type, {}) - with_type[hook] = None - return d - def get_hooks_for_clip_schedule(self): scheduled_hooks: dict[WeightHook, list[tuple[tuple[float,float], HookKeyframe]]] = {} - for hook in self.hooks: - # only care about WeightHooks, for now - if hook.hook_type == EnumHookType.Weight: - hook_schedule = [] - # if no hook keyframes, assign default value - if len(hook.hook_keyframe.keyframes) == 0: - hook_schedule.append(((0.0, 1.0), None)) - scheduled_hooks[hook] = hook_schedule - continue - # find ranges of values - prev_keyframe = hook.hook_keyframe.keyframes[0] - for keyframe in hook.hook_keyframe.keyframes: - if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): - hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) - prev_keyframe = keyframe - elif keyframe.start_percent == prev_keyframe.start_percent: - prev_keyframe = keyframe - # create final range, assuming last start_percent was not 1.0 - if not math.isclose(prev_keyframe.start_percent, 1.0): - hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + # only care about WeightHooks, for now + for hook in self.get_type(EnumHookType.Weight): + hook: WeightHook + hook_schedule = [] + # if no hook keyframes, assign default value + if len(hook.hook_keyframe.keyframes) == 0: + hook_schedule.append(((0.0, 1.0), None)) scheduled_hooks[hook] = hook_schedule + continue + # find ranges of values + prev_keyframe = hook.hook_keyframe.keyframes[0] + for keyframe in hook.hook_keyframe.keyframes: + if keyframe.start_percent > prev_keyframe.start_percent and not math.isclose(keyframe.strength, prev_keyframe.strength): + hook_schedule.append(((prev_keyframe.start_percent, keyframe.start_percent), prev_keyframe)) + prev_keyframe = keyframe + elif keyframe.start_percent == prev_keyframe.start_percent: + prev_keyframe = keyframe + # create final range, assuming last start_percent was not 1.0 + if not math.isclose(prev_keyframe.start_percent, 1.0): + hook_schedule.append(((prev_keyframe.start_percent, 1.0), prev_keyframe)) + scheduled_hooks[hook] = hook_schedule # hooks should not have their schedules in a list of tuples all_ranges: list[tuple[float, float]] = [] for range_kfs in scheduled_hooks.values(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2db21bdc4..0430430e5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,16 +940,16 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): self.restore_hook_patches() registered_hooks: list[comfy.hooks.Hook] = [] - # handle WrapperHooks, if model_options provided + # handle TransformerOptionsHooks, if model_options provided if model_options is not None: - for hook in hooks_dict.get(comfy.hooks.EnumHookType.TransformerOptions, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): hook.add_hook_patches(self, model_options, target_dict, registered_hooks) # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] - for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}): + for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) if len(weight_hooks_to_register) > 0: @@ -958,7 +958,7 @@ class ModelPatcher: for hook in weight_hooks_to_register: hook.add_hook_patches(self, model_options, target_dict, registered_hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks_dict, target_dict) + callback(self, hooks, target_dict) def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 6f21ca3cf..abd44cf6e 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -24,15 +24,13 @@ def get_models_from_cond(cond, model_type): models += [c[model_type]] return models -def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]]): +def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup): # get hooks from conds, and collect cnets so they can be checked for extra_hooks cnets: list[ControlBase] = [] for c in cond: if 'hooks' in c: for hook in c['hooks'].hooks: - hook: comfy.hooks.Hook - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) if 'control' in c: cnets.append(c['control']) @@ -50,10 +48,9 @@ def get_hooks_from_cond(cond, hooks_dict: dict[comfy.hooks.EnumHookType, dict[co extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list) if extra_hooks is not None: for hook in extra_hooks.hooks: - with_type = hooks_dict.setdefault(hook.hook_type, {}) - with_type[hook] = None + full_hooks.add(hook) - return hooks_dict + return full_hooks def convert_cond(cond): out = [] @@ -73,7 +70,7 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]] = {} + hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") @@ -90,7 +87,10 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [x.model for x in hooks.get(comfy.hooks.EnumHookType.AddModels, {}).keys()] + hook_models = [] + for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + x: comfy.hooks.AddModelsHook + hook_models.extend(x.models) models = control_models + gligen + add_models + hook_models return models, inference_memory @@ -124,7 +124,7 @@ def cleanup_models(conds, models): def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # check for hooks in conds - if not registered, see if can be applied - hooks = {} + hooks = comfy.hooks.HookGroup() for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 49b90b9d5..642238340 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -255,7 +255,7 @@ class SetClipHooks: clip.use_clip_schedule = schedule_clip if not clip.use_clip_schedule: clip.patcher.forced_hooks.set_keyframes_on_hooks(None) - clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) + clip.patcher.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip)) return (clip,) class ConditioningTimestepsRange: From 6620d86318d19562a4410eabc78c27538d54e445 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 15:26:22 -0600 Subject: [PATCH 05/52] In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_options to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch --- comfy/hooks.py | 2 +- comfy/samplers.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 79a7090ba..3cb0f3963 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -442,7 +442,7 @@ class HookKeyframeGroup: return False if curr_t == self._curr_t: return False - max_sigma = torch.max(transformer_options["sigmas"]) + max_sigma = torch.max(transformer_options["sample_sigmas"]) prev_index = self._current_index prev_strength = self._current_strength # if met guaranteed steps, look for next keyframe in case need to switch diff --git a/comfy/samplers.py b/comfy/samplers.py index 89464a42a..af2b8e110 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -849,7 +849,7 @@ class CFGGuider: self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options) - extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas + extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas extra_args = {"model_options": extra_model_options, "seed": seed} executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( From 8270ff312f7aefc4d29aeeed667296b2a56628ce Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 21:07:02 -0600 Subject: [PATCH 06/52] Refactored 'registered' to be HookGroup instead of a list of Hooks, made AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time --- comfy/hooks.py | 34 +++++++++++++++--------- comfy/model_patcher.py | 15 +++++------ comfy/sampler_helpers.py | 48 +++++++++++++++++++++++++-------- comfy/samplers.py | 57 +++++++++++++++++++++++++++++++++++++--- 4 files changed, 119 insertions(+), 35 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 3ead8c963..25d67b86c 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -65,7 +65,7 @@ class _HookRef: pass -def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): +def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): '''Example for how should_register function should look like.''' return True @@ -114,10 +114,10 @@ class Hook: c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c - def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): return self.custom_should_register(self, model, model_options, target_dict, registered) - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): @@ -154,7 +154,7 @@ class WeightHook(Hook): def strength_clip(self): return self._strength_clip * self.strength - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False weights = None @@ -178,7 +178,7 @@ class WeightHook(Hook): else: weights = self.weights model.add_hook_patches(hook=self, patches=weights, strength_patch=strength) - registered.append(self) + registered.add(self) return True # TODO: add logs about any keys that were not applied @@ -212,11 +212,12 @@ class AddModelsHook(Hook): Note, value of hook_scope is ignored and is treated as AllConditioning. ''' - def __init__(self, key: str=None, models: list[ModelPatcher]=None): + def __init__(self, models: list[ModelPatcher]=None, key: str=None): super().__init__(hook_type=EnumHookType.AddModels) - self.key = key self.models = models + self.key = key self.append_when_same = True + '''Curently does nothing.''' def clone(self, subtype: Callable=None): if subtype is None: @@ -227,9 +228,10 @@ class AddModelsHook(Hook): c.append_when_same = self.append_when_same return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False + registered.add(self) return True class TransformerOptionsHook(Hook): @@ -247,14 +249,17 @@ class TransformerOptionsHook(Hook): c.transformers_dict = self.transformers_dict return c - def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]): + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False - add_model_options = {"transformer_options": self.transformers_dict} - # TODO: call .to on patches/anything else in transformer_options that is expected to do something + # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks if self.hook_scope == EnumHookScope.AllConditioning: - comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.append(self) + add_model_options = {"transformer_options": self.transformers_dict, + "to_load_options": self.transformers_dict} + else: + add_model_options = {"to_load_options": self.transformers_dict} + comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) + registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -295,6 +300,9 @@ class HookGroup: self.hooks: list[Hook] = [] self._hook_dict: dict[EnumHookType, list[Hook]] = {} + def __len__(self): + return len(self.hooks) + def add(self, hook: Hook): if hook not in self.hooks: self.hooks.append(hook) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0430430e5..2a5510873 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -940,13 +940,11 @@ class ModelPatcher: if reset_current_hooks: self.patch_hooks(None) - def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None): + def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, + registered: comfy.hooks.HookGroup = None): self.restore_hook_patches() - registered_hooks: list[comfy.hooks.Hook] = [] - # handle TransformerOptionsHooks, if model_options provided - if model_options is not None: - for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + if registered is None: + registered = comfy.hooks.HookGroup() # handle WeightHooks weight_hooks_to_register: list[comfy.hooks.WeightHook] = [] for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): @@ -956,9 +954,10 @@ class ModelPatcher: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) for hook in weight_hooks_to_register: - hook.add_hook_patches(self, model_options, target_dict, registered_hooks) + hook.add_hook_patches(self, model_options, target_dict, registered) for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES): - callback(self, hooks, target_dict) + callback(self, hooks, target_dict, model_options, registered) + return registered def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0): with self.use_ejected(): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index abd44cf6e..cb9388519 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -70,13 +70,11 @@ def get_additional_models(conds, dtype): cnets: list[ControlBase] = [] gligen = [] add_models = [] - hooks = comfy.hooks.HookGroup() for k in conds: cnets += get_models_from_cond(conds[k], "control") gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - get_hooks_from_cond(conds[k], hooks) control_nets = set(cnets) @@ -87,14 +85,20 @@ def get_additional_models(conds, dtype): inference_memory += m.inference_memory_requirements(dtype) gligen = [x[1] for x in gligen] - hook_models = [] - for x in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - x: comfy.hooks.AddModelsHook - hook_models.extend(x.models) - models = control_models + gligen + add_models + hook_models + models = control_models + gligen + add_models return models, inference_memory +def get_additional_models_from_model_options(model_options: dict[str]=None): + """loads additional models from registered AddModels hooks""" + models = [] + if model_options is not None and "registered_hooks" in model_options: + registered: comfy.hooks.HookGroup = model_options["registered_hooks"] + for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + models.extend(hook.models) + return models + def cleanup_additional_models(models): """cleanup additional models that were loaded""" for m in models: @@ -102,9 +106,10 @@ def cleanup_additional_models(models): m.cleanup() -def prepare_sampling(model: 'ModelPatcher', noise_shape, conds): - real_model: 'BaseModel' = None +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): + real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) + models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory @@ -130,5 +135,26 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): # add wrappers and callbacks from ModelPatcher to transformer_options model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) - # register hooks on model/model_options - model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options) + # begin registering hooks + registered = comfy.hooks.HookGroup() + target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) + # handle all TransformerOptionsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.TransformerOptions): + hook: comfy.hooks.TransformerOptionsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all AddModelsHooks + for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): + hook: comfy.hooks.AddModelsHook + hook.add_hook_patches(model, model_options, target_dict, registered) + # handle all WeightHooks by registering on ModelPatcher + model.register_all_hook_patches(hooks, target_dict, model_options, registered) + # add registered_hooks onto model_options for further reference + if len(registered) > 0: + model_options["registered_hooks"] = registered + # merge original wrappers and callbacks with hooked wrappers and callbacks + to_load_options: dict[str] = model_options.setdefault("to_load_options", {}) + for wc_name in ["wrappers", "callbacks"]: + comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], + copy_dict1=False) + return to_load_options + diff --git a/comfy/samplers.py b/comfy/samplers.py index af2b8e110..8f8345abc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -819,9 +819,58 @@ def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): return len(hooks_set) +def cast_to_load_options(model_options: dict[str], device=None, dtype=None): + ''' + If any patches from hooks, wrappers, or callbacks have .to to be called, call it. + ''' + if model_options is None: + return + to_load_options = model_options.get("to_load_options", None) + if to_load_options is None: + return + + casts = [] + if device is not None: + casts.append(device) + if dtype is not None: + casts.append(dtype) + # if nothing to apply, do nothing + if len(casts) == 0: + return + + # Try to call .to on patches + if "patches" in to_load_options: + patches = to_load_options["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + for cast in casts: + patch_list[i] = patch_list[i].to(cast) + if "patches_replace" in to_load_options: + patches = to_load_options["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "to"): + for cast in casts: + patch_list[k] = patch_list[k].to(cast) + # Try to call .to on any wrappers/callbacks + wrappers_and_callbacks = ["wrappers", "callbacks"] + for wc_name in wrappers_and_callbacks: + if wc_name in to_load_options: + wc: dict[str, list] = to_load_options[wc_name] + for wc_dict in wc.values(): + for wc_list in wc_dict.values(): + for i in range(len(wc_list)): + if hasattr(wc_list[i], "to"): + for cast in casts: + wc_list[i] = wc_list[i].to(cast) + + class CFGGuider: - def __init__(self, model_patcher): - self.model_patcher: 'ModelPatcher' = model_patcher + def __init__(self, model_patcher: ModelPatcher): + self.model_patcher = model_patcher self.model_options = model_patcher.model_options self.original_conds = {} self.cfg = 1.0 @@ -861,7 +910,7 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds) + self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device if denoise_mask is not None: @@ -870,6 +919,7 @@ class CFGGuider: noise = noise.to(device) latent_image = latent_image.to(device) sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) try: self.model_patcher.pre_run() @@ -906,6 +956,7 @@ class CFGGuider: ) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: + cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) self.model_options = orig_model_options self.model_patcher.hook_mode = orig_hook_mode self.model_patcher.restore_hook_patches() From 4446c86052bd9a00b72205b761b3744dd51f90eb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 5 Jan 2025 22:25:51 -0600 Subject: [PATCH 07/52] Made hook clone code sane, made clear ObjectPatchHook and SetInjectionsHook are not yet operational --- comfy/hooks.py | 64 +++++++++++++++++++--------------------- comfy/sampler_helpers.py | 1 - 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 25d67b86c..b62092cce 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -101,10 +101,8 @@ class Hook: def reset(self): self.hook_keyframe.reset() - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: Hook = subtype() + def clone(self): + c: Hook = self.__class__() c.hook_type = self.hook_type c.hook_ref = self.hook_ref c.hook_id = self.hook_id @@ -182,10 +180,8 @@ class WeightHook(Hook): return True # TODO: add logs about any keys that were not applied - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WeightHook = super().clone(subtype) + def clone(self): + c: WeightHook = super().clone() c.weights = self.weights c.weights_clip = self.weights_clip c.need_weight_init = self.need_weight_init @@ -194,17 +190,21 @@ class WeightHook(Hook): return c class ObjectPatchHook(Hook): - def __init__(self): + def __init__(self, object_patches: dict[str]=None): super().__init__(hook_type=EnumHookType.ObjectPatch) - self.object_patches: dict = None + self.object_patches = object_patches - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: ObjectPatchHook = super().clone(subtype) + def clone(self): + c: ObjectPatchHook = super().clone() c.object_patches = self.object_patches return c - # TODO: add functionality + + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True class AddModelsHook(Hook): ''' @@ -219,12 +219,10 @@ class AddModelsHook(Hook): self.append_when_same = True '''Curently does nothing.''' - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: AddModelsHook = super().clone(subtype) - c.key = self.key + def clone(self): + c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models + c.key = self.key c.append_when_same = self.append_when_same return c @@ -242,10 +240,8 @@ class TransformerOptionsHook(Hook): super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = wrappers_dict - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: WrapperHook = super().clone(subtype) + def clone(self): + c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict return c @@ -265,11 +261,8 @@ class TransformerOptionsHook(Hook): def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) -class WrapperHook(TransformerOptionsHook): - ''' - For backwards compatibility, this hook is identical to TransformerOptionsHook. - ''' - pass +WrapperHook = TransformerOptionsHook +'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None): @@ -277,14 +270,19 @@ class SetInjectionsHook(Hook): self.key = key self.injections = injections - def clone(self, subtype: Callable=None): - if subtype is None: - subtype = type(self) - c: SetInjectionsHook = super().clone(subtype) + def clone(self): + c: SetInjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c + def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): + raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") + if not self.should_register(model, model_options, target_dict, registered): + return False + registered.add(self) + return True + def add_hook_injections(self, model: ModelPatcher): # TODO: add functionality pass diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index cb9388519..d43280fe4 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -157,4 +157,3 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options - From 03a97b604a3e8ca9f54c711ed3b007f07c9115ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 01:03:59 -0600 Subject: [PATCH 08/52] Fix performance of hooks when hooks are appended via Cond Pair Set Props nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time) --- comfy/hooks.py | 33 +++++++++++++++++++++++++-------- comfy/model_patcher.py | 10 ++++++---- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index b62092cce..dde3e8bcb 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -317,6 +317,18 @@ class HookGroup: def contains(self, hook: Hook): return hook in self.hooks + def is_subset_of(self, other: HookGroup): + self_hooks = set(self.hooks) + other_hooks = set(other.hooks) + return self_hooks.issubset(other_hooks) + + def new_with_common_hooks(self, other: HookGroup): + c = HookGroup() + for hook in self.hooks: + if other.contains(hook): + c.add(hook.clone()) + return c + def clone(self): c = HookGroup() for hook in self.hooks: @@ -668,24 +680,26 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H else: c_dict[hooks_key] = cache[hooks_tuple] -def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True): +def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True, + cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): c = [] - hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {} + if cache is None: + cache = {} for t in conditioning: n = [t[0], t[1].copy()] for k in values: if append_hooks and k == 'hooks': - _combine_hooks_from_values(n[1], values, hooks_combine_cache) + _combine_hooks_from_values(n[1], values, cache) else: n[1][k] = values[k] c.append(n) return c -def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True): +def set_hooks_for_conditioning(cond, hooks: HookGroup, append_hooks=True, cache: dict[tuple[HookGroup, HookGroup], HookGroup]=None): if hooks is None: return cond - return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks) + return conditioning_set_values_with_hooks(cond, {'hooks': hooks}, append_hooks=append_hooks, cache=cache) def set_timesteps_for_conditioning(cond, timestep_range: tuple[float,float]): if timestep_range is None: @@ -720,9 +734,10 @@ def combine_with_new_conds(conds: list, new_conds: list): def set_conds_props(conds: list, strength: float, set_cond_area: str, mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): final_conds = [] + cache = {} for c in conds: # first, apply lora_hook to conditioning, if provided - c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks) + c = set_hooks_for_conditioning(c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to conditioning c = set_mask_for_conditioning(cond=c, mask=mask, strength=strength, set_cond_area=set_cond_area) # apply timesteps, if present @@ -734,9 +749,10 @@ def set_conds_props(conds: list, strength: float, set_cond_area: str, def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default", mask: torch.Tensor=None, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, masked_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks) + masked_c = set_hooks_for_conditioning(masked_c, hooks, append_hooks=append_hooks, cache=cache) # next, apply mask to new conditioning, if provided masked_c = set_mask_for_conditioning(cond=masked_c, mask=mask, set_cond_area=set_cond_area, strength=strength) # apply timesteps, if present @@ -748,9 +764,10 @@ def set_conds_props_and_combine(conds: list, new_conds: list, strength: float=1. def set_default_conds_and_combine(conds: list, new_conds: list, hooks: HookGroup=None, timesteps_range: tuple[float,float]=None, append_hooks=True): combined_conds = [] + cache = {} for c, new_c in zip(conds, new_conds): # first, apply lora_hook to new conditioning, if provided - new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks) + new_c = set_hooks_for_conditioning(new_c, hooks, append_hooks=append_hooks, cache=cache) # next, add default_cond key to cond so that during sampling, it can be identified new_c = conditioning_set_values(new_c, {'default': True}) # apply timesteps, if present diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 2a5510873..57a843b8f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -210,7 +210,7 @@ class ModelPatcher: self.injections: dict[str, list[PatcherInjection]] = {} self.hook_patches: dict[comfy.hooks._HookRef] = {} - self.hook_patches_backup: dict[comfy.hooks._HookRef] = {} + self.hook_patches_backup: dict[comfy.hooks._HookRef] = None self.hook_backup: dict[str, tuple[torch.Tensor, torch.device]] = {} self.cached_hook_patches: dict[comfy.hooks.HookGroup, dict[str, torch.Tensor]] = {} self.current_hooks: Optional[comfy.hooks.HookGroup] = None @@ -282,7 +282,7 @@ class ModelPatcher: n.injections[k] = i.copy() # hooks n.hook_patches = create_hook_patches_clone(self.hook_patches) - n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) + n.hook_patches_backup = create_hook_patches_clone(self.hook_patches_backup) if self.hook_patches_backup else self.hook_patches_backup for group in self.cached_hook_patches: n.cached_hook_patches[group] = {} for k in self.cached_hook_patches[group]: @@ -912,9 +912,9 @@ class ModelPatcher: callback(self, timestep) def restore_hook_patches(self): - if len(self.hook_patches_backup) > 0: + if self.hook_patches_backup is not None: self.hook_patches = self.hook_patches_backup - self.hook_patches_backup = {} + self.hook_patches_backup = None def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode): self.hook_mode = hook_mode @@ -950,6 +950,8 @@ class ModelPatcher: for hook in hooks.get_type(comfy.hooks.EnumHookType.Weight): if hook.hook_ref not in self.hook_patches: weight_hooks_to_register.append(hook) + else: + registered.add(hook) if len(weight_hooks_to_register) > 0: # clone hook_patches to become backup so that any non-dynamic hooks will return to their original state self.hook_patches_backup = create_hook_patches_clone(self.hook_patches) From 0a7e2ae787b81035798ad2ef1ade8cf882d67b69 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 01:04:29 -0600 Subject: [PATCH 09/52] Filter only registered hooks on self.conds in CFGGuider.sample --- comfy/sampler_helpers.py | 3 +++ comfy/samplers.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index d43280fe4..1433d1859 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -128,6 +128,9 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): + ''' + Registers hooks from conds. + ''' # check for hooks in conds - if not registered, see if can be applied hooks = comfy.hooks.HookGroup() for k in conds: diff --git a/comfy/samplers.py b/comfy/samplers.py index 8f8345abc..43a735c6e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -810,6 +810,33 @@ def preprocess_conds_hooks(conds: dict[str, list[dict[str]]]): for cond in conds_to_modify: cond['hooks'] = hooks +def filter_registered_hooks_on_conds(conds: dict[str, list[dict[str]]], model_options: dict[str]): + '''Modify 'hooks' on conds so that only hooks that were registered remain. Properly accounts for + HookGroups that have the same reference.''' + registered: comfy.hooks.HookGroup = model_options.get('registered_hooks', None) + # if None were registered, make sure all hooks are cleaned from conds + if registered is None: + for k in conds: + for kk in conds[k]: + kk.pop('hooks', None) + return + # find conds that contain hooks to be replaced - group by common HookGroup refs + hook_replacement: dict[comfy.hooks.HookGroup, list[dict]] = {} + for k in conds: + for kk in conds[k]: + hooks: comfy.hooks.HookGroup = kk.get('hooks', None) + if hooks is not None: + if not hooks.is_subset_of(registered): + to_replace = hook_replacement.setdefault(hooks, []) + to_replace.append(kk) + # for each hook to replace, create a new proper HookGroup and assign to all common conds + for hooks, conds_to_modify in hook_replacement.items(): + new_hooks = hooks.new_with_common_hooks(registered) + if len(new_hooks) == 0: + new_hooks = None + for kk in conds_to_modify: + kk['hooks'] = new_hooks + def get_total_hook_groups_in_conds(conds: dict[str, list[dict[str]]]): hooks_set = set() @@ -949,6 +976,7 @@ class CFGGuider: if get_total_hook_groups_in_conds(self.conds) <= 1: self.model_patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds, self.model_options) + filter_registered_hooks_on_conds(self.conds, self.model_options) executor = comfy.patcher_extension.WrapperExecutor.new_class_executor( self.outer_sample, self, From f48f90e471fc5440135e7886d712518467c59c00 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 02:23:04 -0600 Subject: [PATCH 10/52] Make hook_scope functional for TransformerOptionsHook --- comfy/hooks.py | 41 ++++++++++++++++++++++++++--------------- comfy/model_patcher.py | 4 ++-- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index dde3e8bcb..cc9f6cd54 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -86,9 +86,9 @@ class Hook: self.hook_ref = hook_ref if hook_ref else _HookRef() self.hook_id = hook_id self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + self.hook_scope = hook_scope self.custom_should_register = default_should_register self.auto_apply_to_nonpositive = False - self.hook_scope = hook_scope @property def strength(self): @@ -107,6 +107,7 @@ class Hook: c.hook_ref = self.hook_ref c.hook_id = self.hook_id c.hook_keyframe = self.hook_keyframe + c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register # TODO: make this do something c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive @@ -118,12 +119,6 @@ class Hook: def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("add_hook_patches should be defined for Hook subclasses") - def on_apply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - - def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]): - pass - def __eq__(self, other: Hook): return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref @@ -143,6 +138,7 @@ class WeightHook(Hook): self.need_weight_init = True self._strength_model = strength_model self._strength_clip = strength_clip + self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs @property def strength_model(self): @@ -190,9 +186,11 @@ class WeightHook(Hook): return c class ObjectPatchHook(Hook): - def __init__(self, object_patches: dict[str]=None): + def __init__(self, object_patches: dict[str]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.ObjectPatch) self.object_patches = object_patches + self.hook_scope = hook_scope def clone(self): c: ObjectPatchHook = super().clone() @@ -216,14 +214,11 @@ class AddModelsHook(Hook): super().__init__(hook_type=EnumHookType.AddModels) self.models = models self.key = key - self.append_when_same = True - '''Curently does nothing.''' def clone(self): c: AddModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key - c.append_when_same = self.append_when_same return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook): ''' Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options. ''' - def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None): + def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.TransformerOptions) - self.transformers_dict = wrappers_dict + self.transformers_dict = transformers_dict + self.hook_scope = hook_scope def clone(self): c: TransformerOptionsHook = super().clone() @@ -254,8 +251,9 @@ class TransformerOptionsHook(Hook): "to_load_options": self.transformers_dict} else: add_model_options = {"to_load_options": self.transformers_dict} + # only register if will not be included in AllConditioning to avoid double loading + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) - registered.add(self) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): @@ -265,10 +263,12 @@ WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' class SetInjectionsHook(Hook): - def __init__(self, key: str=None, injections: list[PatcherInjection]=None): + def __init__(self, key: str=None, injections: list[PatcherInjection]=None, + hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) self.key = key self.injections = injections + self.hook_scope = hook_scope def clone(self): c: SetInjectionsHook = super().clone() @@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list: sorted_list.extend(object_list) return sorted_list +def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None): + # if no hooks or is not a ModelPatcher for sampling, return empty dict + if hooks is None or model.is_clip: + return {} + if transformer_options is None: + transformer_options = {} + for hook in hooks.get_type(EnumHookType.TransformerOptions): + hook: TransformerOptionsHook + hook.on_apply_hooks(model, transformer_options) + return transformer_options + def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float): hook_group = HookGroup() hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57a843b8f..51a62e048 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1010,11 +1010,11 @@ class ModelPatcher: def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False): # TODO: return transformer_options dict with any additions from hooks if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)): - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) self.patch_hooks(hooks=hooks) for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS): callback(self, hooks) - return {} + return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options) def patch_hooks(self, hooks: comfy.hooks.HookGroup): with self.use_ejected(): From 1b38f5bf57ca07490e616dd58ec3004d05de0655 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 17:11:12 -0600 Subject: [PATCH 11/52] removed 4 whitespace lines to satisfy Ruff, --- comfy/hooks.py | 4 ++-- comfy/samplers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index cc9f6cd54..46fc06bdc 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -255,7 +255,7 @@ class TransformerOptionsHook(Hook): registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True - + def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) @@ -290,7 +290,7 @@ class SetInjectionsHook(Hook): class HookGroup: ''' Stores groups of hooks, and allows them to be queried by type. - + To prevent breaking their functionality, never modify the underlying self.hooks or self._hook_dict vars directly; always use the provided functions on HookGroup. ''' diff --git a/comfy/samplers.py b/comfy/samplers.py index 43a735c6e..a725d5185 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -855,7 +855,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return - + casts = [] if device is not None: casts.append(device) @@ -864,7 +864,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - + # Try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] From 58bf8815c84b67ab26b0f08b8530a822b9899b10 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 6 Jan 2025 20:34:30 -0600 Subject: [PATCH 12/52] Add a get_injections function to ModelPatcher --- comfy/model_patcher.py | 3 +++ comfy/samplers.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 51a62e048..7d7977c14 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -842,6 +842,9 @@ class ModelPatcher: if key in self.injections: self.injections.pop(key) + def get_injections(self, key: str): + return self.injections.get(key, None) + def set_additional_models(self, key: str, models: list['ModelPatcher']): self.additional_models[key] = models diff --git a/comfy/samplers.py b/comfy/samplers.py index a725d5185..5cc33a7d9 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -865,7 +865,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if len(casts) == 0: return - # Try to call .to on patches + # try to call .to on patches if "patches" in to_load_options: patches = to_load_options["patches"] for name in patches: @@ -882,7 +882,7 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): if hasattr(patch_list[k], "to"): for cast in casts: patch_list[k] = patch_list[k].to(cast) - # Try to call .to on any wrappers/callbacks + # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: if wc_name in to_load_options: From 216fea15ee033d3301241a5ceb0e193b4924de04 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 00:59:18 -0600 Subject: [PATCH 13/52] Made TransformerOptionsHook contribute to registered hooks properly, added some doc strings and removed a so-far unused variable --- comfy/hooks.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 46fc06bdc..7c2f66892 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -66,7 +66,7 @@ class _HookRef: def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - '''Example for how should_register function should look like.''' + '''Example for how custom_should_register function can look like.''' return True @@ -83,12 +83,17 @@ class Hook: def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None, hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning): self.hook_type = hook_type + '''Enum identifying the general class of this hook.''' self.hook_ref = hook_ref if hook_ref else _HookRef() + '''Reference shared between hook clones that have the same value. Should NOT be modified.''' self.hook_id = hook_id + '''Optional string ID to identify hook; useful if need to consolidate duplicates at registration time.''' self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup() + '''Keyframe storage that can be referenced to get strength for current sampling step.''' self.hook_scope = hook_scope + '''Scope of where this hook should apply in terms of the conds used in sampling run.''' self.custom_should_register = default_should_register - self.auto_apply_to_nonpositive = False + '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' @property def strength(self): @@ -109,8 +114,6 @@ class Hook: c.hook_keyframe = self.hook_keyframe c.hook_scope = self.hook_scope c.custom_should_register = self.custom_should_register - # TODO: make this do something - c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive return c def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): @@ -236,28 +239,34 @@ class TransformerOptionsHook(Hook): super().__init__(hook_type=EnumHookType.TransformerOptions) self.transformers_dict = transformers_dict self.hook_scope = hook_scope + self._skip_adding = False + '''Internal value used to avoid double load of transformer_options when hook_scope is AllConditioning.''' def clone(self): c: TransformerOptionsHook = super().clone() c.transformers_dict = self.transformers_dict + c._skip_adding = self._skip_adding return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): if not self.should_register(model, model_options, target_dict, registered): return False # NOTE: to_load_options will be used to manually load patches/wrappers/callbacks from hooks + self._skip_adding = False if self.hook_scope == EnumHookScope.AllConditioning: add_model_options = {"transformer_options": self.transformers_dict, "to_load_options": self.transformers_dict} + # skip_adding if included in AllConditioning to avoid double loading + self._skip_adding = True else: add_model_options = {"to_load_options": self.transformers_dict} - # only register if will not be included in AllConditioning to avoid double loading - registered.add(self) + registered.add(self) comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False) return True def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]): - comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) + if not self._skip_adding: + comfy.patcher_extension.merge_nested_dicts(transformer_options, self.transformers_dict, copy_dict1=False) WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' From 3cd4c5cb0a9d4f4f944ee1382e074d3a41e18874 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 02:22:49 -0600 Subject: [PATCH 14/52] Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHook to InjectionsHook (not yet implemented, but at least getting the naming figured out) --- comfy/hooks.py | 26 +++++++------------------- comfy/sampler_helpers.py | 8 ++++---- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/comfy/hooks.py b/comfy/hooks.py index 7c2f66892..9d0731072 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -42,7 +42,7 @@ class EnumHookType(enum.Enum): ''' Weight = "weight" ObjectPatch = "object_patch" - AddModels = "add_models" + AdditionalModels = "add_models" TransformerOptions = "transformer_options" Injections = "add_injections" @@ -202,24 +202,20 @@ class ObjectPatchHook(Hook): def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): raise NotImplementedError("ObjectPatchHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True -class AddModelsHook(Hook): +class AdditionalModelsHook(Hook): ''' Hook responsible for telling model management any additional models that should be loaded. Note, value of hook_scope is ignored and is treated as AllConditioning. ''' def __init__(self, models: list[ModelPatcher]=None, key: str=None): - super().__init__(hook_type=EnumHookType.AddModels) + super().__init__(hook_type=EnumHookType.AdditionalModels) self.models = models self.key = key def clone(self): - c: AddModelsHook = super().clone() + c: AdditionalModelsHook = super().clone() c.models = self.models.copy() if self.models else self.models c.key = self.key return c @@ -271,7 +267,7 @@ class TransformerOptionsHook(Hook): WrapperHook = TransformerOptionsHook '''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.''' -class SetInjectionsHook(Hook): +class InjectionsHook(Hook): def __init__(self, key: str=None, injections: list[PatcherInjection]=None, hook_scope=EnumHookScope.AllConditioning): super().__init__(hook_type=EnumHookType.Injections) @@ -280,21 +276,13 @@ class SetInjectionsHook(Hook): self.hook_scope = hook_scope def clone(self): - c: SetInjectionsHook = super().clone() + c: InjectionsHook = super().clone() c.key = self.key c.injections = self.injections.copy() if self.injections else self.injections return c def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup): - raise NotImplementedError("SetInjectionsHook is not supported yet in ComfyUI.") - if not self.should_register(model, model_options, target_dict, registered): - return False - registered.add(self) - return True - - def add_hook_injections(self, model: ModelPatcher): - # TODO: add functionality - pass + raise NotImplementedError("InjectionsHook is not supported yet in ComfyUI.") class HookGroup: ''' diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1433d1859..b70e5e636 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -94,8 +94,8 @@ def get_additional_models_from_model_options(model_options: dict[str]=None): models = [] if model_options is not None and "registered_hooks" in model_options: registered: comfy.hooks.HookGroup = model_options["registered_hooks"] - for hook in registered.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in registered.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook models.extend(hook.models) return models @@ -146,8 +146,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): hook: comfy.hooks.TransformerOptionsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all AddModelsHooks - for hook in hooks.get_type(comfy.hooks.EnumHookType.AddModels): - hook: comfy.hooks.AddModelsHook + for hook in hooks.get_type(comfy.hooks.EnumHookType.AdditionalModels): + hook: comfy.hooks.AdditionalModelsHook hook.add_hook_patches(model, model_options, target_dict, registered) # handle all WeightHooks by registering on ModelPatcher model.register_all_hook_patches(hooks, target_dict, model_options, registered) From 733328169868b9f4120cbfc59af2b00683df8563 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 02:58:59 -0600 Subject: [PATCH 15/52] Clean up a typehint --- comfy_extras/nodes_hooks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 642238340..1edc06f3d 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -246,7 +246,7 @@ class SetClipHooks: CATEGORY = "advanced/hooks/clip" FUNCTION = "apply_hooks" - def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): + def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: clip = clip.clone() if apply_to_conds: From 871258aa722fb8031e251c7e4d0ecffa9a11c460 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 7 Jan 2025 21:06:03 -0600 Subject: [PATCH 16/52] Add get_all_torch_devices to get detected devices intended for current torch hardware device --- comfy/model_management.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index f6dfc18b0..003a89f51 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -128,6 +128,19 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +def get_all_torch_devices(exclude_current=False): + global cpu_state + devices = [] + if cpu_state == CPUState.GPU: + if is_nvidia(): + for i in range(torch.cuda.device_count()): + devices.append(torch.device(i)) + else: + devices.append(get_torch_device()) + if exclude_current: + devices.remove(get_torch_device()) + return devices + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: From 7448f02b7cf0e7acf97fcdc41eda4342d062e549 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 8 Jan 2025 03:33:05 -0600 Subject: [PATCH 17/52] Initial proof of concept of giving splitting cond sampling between multiple GPUs --- comfy/model_management.py | 10 +- comfy/model_patcher.py | 4 + comfy/samplers.py | 188 +++++++++++++++++++++++++++++++++++++- 3 files changed, 198 insertions(+), 4 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 003a89f51..87ad290d0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from __future__ import annotations import psutil import logging @@ -26,6 +27,11 @@ import platform import weakref import gc +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.model_base import BaseModel + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -330,7 +336,7 @@ def module_size(module): return module_mem class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelPatcher): self._set_model(model) self.device = model.load_device self.real_model = None @@ -338,7 +344,7 @@ class LoadedModel: self.model_finalizer = None self._patcher_finalizer = None - def _set_model(self, model): + def _set_model(self, model: ModelPatcher): self._model = weakref.ref(model) if model.parent is not None: self._parent_model = weakref.ref(model.parent) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0501f7b38..5465dde62 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -218,6 +218,8 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.is_multigpu_clone = False + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -293,6 +295,8 @@ class ModelPatcher: n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.is_multigpu_clone = self.is_multigpu_clone + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n diff --git a/comfy/samplers.py b/comfy/samplers.py index 5cc33a7d9..f30640006 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -19,6 +19,7 @@ import comfy.patcher_extension import comfy.hooks import scipy.stats import numpy +import threading def get_area_and_mult(conds, x_in, timestep_in): dims = tuple(x_in.shape[2:]) @@ -130,7 +131,7 @@ def can_concat_cond(c1, c2): return cond_equal_size(c1.conditioning, c2.conditioning) -def cond_cat(c_list): +def cond_cat(c_list, device=None): temp = {} for x in c_list: for k in x: @@ -142,6 +143,8 @@ def cond_cat(c_list): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) + if device is not None: + out[k] = out[k].to(device) return out @@ -195,7 +198,9 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Ten ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + if 'multigpu_clones' in model_options: + return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] out_counts = [] # separate conds by matching hooks @@ -329,6 +334,173 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te return out_conds +def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + out_conds = [] + out_counts = [] + # separate conds by matching hooks + hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} + default_conds = [] + has_default_conds = False + + output_device = x_in.device + + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + default_c = [] + if cond is not None: + for x in cond: + if 'default' in x: + default_c.append(x) + has_default_conds = True + continue + p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + if p is None: + continue + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + default_conds.append(default_c) + + if has_default_conds: + finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) + + model.current_patcher.prepare_state(timestep) + + devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()] + device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + count = 0 + # run every hooked_to_run separately + for hooks, to_run in hooked_to_run.items(): + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + current_device = devices[count % len(devices)] + free_memory = model_management.get_free_memory(current_device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + # if model.memory_required(input_shape) * 1.5 < free_memory: + # to_batch = batch_amount + # break + conds_to_batch = [] + for x in to_batch: + conds_to_batch.append(to_run.pop(x)) + + batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + batched_to_run.append((hooks, conds_to_batch)) + count += 1 + + thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + + c['transformer_options'] = transformer_options + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + + + results: list[thread_result] = [] + threads: list[threading.Thread] = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) + threads.append(new_thread) + new_thread.start() + + for thread in threads: + thread.join() + + for output, mult, area, batch_chunks, cond_or_uncond in results: + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) @@ -940,6 +1112,14 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device + multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[device] = self.model_patcher + for x in multigpu_patchers: + multigpu_dict[x.load_device] = x + self.model_options["multigpu_clones"] = multigpu_dict + if denoise_mask is not None: denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) @@ -950,9 +1130,13 @@ class CFGGuider: try: self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) finally: self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model From e88c6c03ff16c197e7b49b7908f91a67f21ef7b1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 10 Jan 2025 23:05:24 -0600 Subject: [PATCH 18/52] Fix cond_cat to not try to cast anything that doesn't have a 'to' function --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f30640006..98b1932f7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -143,7 +143,7 @@ def cond_cat(c_list, device=None): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) - if device is not None: + if device is not None and hasattr(out[k], 'to'): out[k] = out[k].to(device) return out From d5088072fb7561e6c6b44693c65e31c254c81b81 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 13 Jan 2025 20:20:25 -0600 Subject: [PATCH 19/52] Make test node for multigpu instead of storing it in just a local __init__.py --- comfy_extras/nodes_multigpu.py | 39 ++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 comfy_extras/nodes_multigpu.py diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py new file mode 100644 index 000000000..929151b50 --- /dev/null +++ b/comfy_extras/nodes_multigpu.py @@ -0,0 +1,39 @@ +from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management +import copy + + +class MultiGPUInitialize: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "model": ("MODEL",), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "init_multigpu" + CATEGORY = "DevTools" + + def init_multigpu(self, model: ModelPatcher): + model = model.clone() + extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + if len(extra_devices) > 0: + comfy.model_management.unload_all_models() + for device in extra_devices: + device_patcher = model.clone() + device_patcher.model = copy.deepcopy(model.model) + device_patcher.load_device = device + device_patcher.is_multigpu_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + return (model,) + + +NODE_CLASS_MAPPINGS = { + "test_multigpuinit": MultiGPUInitialize, +} \ No newline at end of file From 198953cd088b8a02701315e59047db16a6e6438a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 14 Jan 2025 12:24:55 -0600 Subject: [PATCH 20/52] Add nodes_multigpu.py to loaded nodes --- nodes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nodes.py b/nodes.py index cfd7dd8a4..62b6ad18a 100644 --- a/nodes.py +++ b/nodes.py @@ -2224,6 +2224,7 @@ def init_builtin_extra_nodes(): "nodes_mahiro.py", "nodes_lt.py", "nodes_hooks.py", + "nodes_multigpu.py", "nodes_load_3d.py", "nodes_cosmos.py", ] From 25818dc848f8db6f79b4410e46b06133165d35a2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 14 Jan 2025 13:45:14 -0600 Subject: [PATCH 21/52] Added a 'max_gpus' input --- comfy_extras/nodes_multigpu.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 929151b50..3ba558621 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,6 +11,9 @@ class MultiGPUInitialize: return { "required": { "model": ("MODEL",), + }, + "optional": { + "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), } } @@ -18,9 +21,10 @@ class MultiGPUInitialize: FUNCTION = "init_multigpu" CATEGORY = "DevTools" - def init_multigpu(self, model: ModelPatcher): + def init_multigpu(self, model: ModelPatcher, max_gpus: int): model = model.clone() extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: comfy.model_management.unload_all_models() for device in extra_devices: From bfce72331188c4efdfe41edcf4e941ac328632cb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 17 Jan 2025 03:31:28 -0600 Subject: [PATCH 22/52] Initial work on multigpu_clone function, which will account for additional_models getting cloned --- comfy/model_patcher.py | 25 +++++++++++++++++++++++++ comfy_extras/nodes_multigpu.py | 6 ++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5465dde62..63f1f92e4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -219,6 +219,7 @@ class ModelPatcher: self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.is_multigpu_clone = False + self.clone_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -296,11 +297,35 @@ class ModelPatcher: n.hook_mode = self.hook_mode n.is_multigpu_clone = self.is_multigpu_clone + n.clone_uuid = self.clone_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n + def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + n = self.clone() + # set load device, if present + if new_load_device is not None: + n.load_device = new_load_device + # unlike for normal clone, backup dicts that shared same ref should not; + # otherwise, patchers that have deep copies of base models will erroneously influence each other. + n.backup = copy.deepcopy(n.backup) + n.object_patches_backup = copy.deepcopy(n.object_patches_backup) + n.model = copy.deepcopy(n.model) + # multigpu clone should not have multigpu additional_models entry + n.remove_additional_models("multigpu") + # multigpu_clone all stored additional_models; make sure circular references are properly handled + if models_cache is None: + models_cache = {} + for key, model_list in n.additional_models.items(): + for i in range(len(model_list)): + add_model = n.additional_models[key][i] + if i not in models_cache: + models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model] + return n + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 3ba558621..dec395fb3 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -22,15 +22,13 @@ class MultiGPUInitialize: CATEGORY = "DevTools" def init_multigpu(self, model: ModelPatcher, max_gpus: int): - model = model.clone() extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: + model = model.clone() comfy.model_management.unload_all_models() for device in extra_devices: - device_patcher = model.clone() - device_patcher.model = copy.deepcopy(model.model) - device_patcher.load_device = device + device_patcher = model.multigpu_clone(new_load_device=device) device_patcher.is_multigpu_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) From 328d4f16a90f5d5d8ac1218bcc5b0862fe970afb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 20 Jan 2025 04:34:26 -0600 Subject: [PATCH 23/52] Make WeightHooks compatible with MultiGPU, clean up some code --- comfy/model_patcher.py | 46 ++++++++++++++++++++++++++++++---- comfy/sampler_helpers.py | 18 ++++++++++++- comfy/samplers.py | 46 ++++++++++++++++++++-------------- comfy_extras/nodes_multigpu.py | 2 +- 4 files changed, 86 insertions(+), 26 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 63f1f92e4..46779397e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -84,12 +84,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) -def create_hook_patches_clone(orig_hook_patches): +def create_hook_patches_clone(orig_hook_patches, copy_tuples=False): new_hook_patches = {} for hook_ref in orig_hook_patches: new_hook_patches[hook_ref] = {} for k in orig_hook_patches[hook_ref]: new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + if copy_tuples: + for i in range(len(new_hook_patches[hook_ref][k])): + new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i]) return new_hook_patches def wipe_lowvram_weight(m): @@ -303,7 +306,7 @@ class ModelPatcher: callback(self, n) return n - def multigpu_clone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): n = self.clone() # set load device, if present if new_load_device is not None: @@ -312,6 +315,7 @@ class ModelPatcher: # otherwise, patchers that have deep copies of base models will erroneously influence each other. n.backup = copy.deepcopy(n.backup) n.object_patches_backup = copy.deepcopy(n.object_patches_backup) + n.hook_backup = copy.deepcopy(n.hook_backup) n.model = copy.deepcopy(n.model) # multigpu clone should not have multigpu additional_models entry n.remove_additional_models("multigpu") @@ -322,7 +326,7 @@ class ModelPatcher: for i in range(len(model_list)): add_model = n.additional_models[key][i] if i not in models_cache: - models_cache[add_model] = add_model.multigpu_clone(new_load_device=new_load_device, models_cache=models_cache) + models_cache[add_model] = add_model.multigpu_deepclone(new_load_device=new_load_device, models_cache=models_cache) n.additional_models[key][i] = models_cache[add_model] return n @@ -952,9 +956,13 @@ class ModelPatcher: for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep): + def prepare_state(self, timestep, model_options, ignore_multigpu=False): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep) + callback(self, timestep, model_options, ignore_multigpu) + if not ignore_multigpu and "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options, ignore_multigpu=True) def restore_hook_patches(self): if self.hook_patches_backup is not None: @@ -967,12 +975,18 @@ class ModelPatcher: def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + multigpu_kf_changed_cache = None transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: + # cache changed for multigpu usage + if "multigpu_clones" in model_options: + if multigpu_kf_changed_cache is None: + multigpu_kf_changed_cache = [] + multigpu_kf_changed_cache.append(hook) # reset current_hooks if contains hook that changed if self.current_hooks is not None: for current_hook in self.current_hooks.hooks: @@ -984,6 +998,28 @@ class ModelPatcher: self.cached_hook_patches.pop(cached_group) if reset_current_hooks: self.patch_hooks(None) + if "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p._handle_changed_hook_keyframes(multigpu_kf_changed_cache) + + def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]): + 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.' + if kf_changed_cache is None: + return + reset_current_hooks = False + # reset current_hooks if contains hook that changed + for hook in kf_changed_cache: + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, registered: comfy.hooks.HookGroup = None): diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index b70e5e636..a95231ff5 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,7 +1,9 @@ from __future__ import annotations +import torch import uuid import comfy.model_management import comfy.conds +import comfy.model_patcher import comfy.utils import comfy.hooks import comfy.patcher_extension @@ -127,7 +129,7 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) -def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): +def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): ''' Registers hooks from conds. ''' @@ -160,3 +162,17 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options + +def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict): + ''' + In case multigpu acceleration is enabled, prep ModelPatchers for each device. + ''' + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[model_patcher.load_device] = model_patcher + for x in multigpu_patchers: + x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + multigpu_dict[x.load_device] = x + model_options["multigpu_clones"] = multigpu_dict + return multigpu_patchers diff --git a/comfy/samplers.py b/comfy/samplers.py index e9cd076e9..dde0b6521 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -232,7 +232,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + model.current_patcher.prepare_state(timestep, model_options) # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): @@ -368,39 +368,53 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + model.current_patcher.prepare_state(timestep, model_options) - devices = [dev_m for dev_m in model_options["multigpu_clones"].keys()] + devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()] device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} - count = 0 + + total_conds = 0 + for to_run in hooked_to_run.values(): + total_conds += len(to_run) + conds_per_device = max(1, math.ceil(total_conds//len(devices))) + index_device = 0 + current_device = devices[index_device] # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): while len(to_run) > 0: + current_device = devices[index_device % len(devices)] + batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + # keep track of conds currently scheduled onto this device + batched_to_run_length = 0 + for btr in batched_to_run: + batched_to_run_length += len(btr[1]) + first = to_run[0] first_shape = first[0][0].shape to_batch_temp = [] + # make sure not over conds_per_device limit when creating temp batch for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): + if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length): to_batch_temp += [x] to_batch_temp.reverse() to_batch = to_batch_temp[:1] - current_device = devices[count % len(devices)] free_memory = model_management.get_free_memory(current_device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - # if model.memory_required(input_shape) * 1.5 < free_memory: - # to_batch = batch_amount - # break + if model.memory_required(input_shape) * 1.5 < free_memory: + to_batch = batch_amount + break conds_to_batch = [] for x in to_batch: conds_to_batch.append(to_run.pop(x)) - - batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) + batched_to_run_length += len(conds_to_batch) + batched_to_run.append((hooks, conds_to_batch)) - count += 1 + if batched_to_run_length >= conds_per_device: + index_device += 1 thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): @@ -1112,13 +1126,7 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - multigpu_patchers: list[ModelPatcher] = [x for x in self.loaded_models if x.is_multigpu_clone] - if len(multigpu_patchers) > 0: - multigpu_dict: dict[torch.device, ModelPatcher] = {} - multigpu_dict[device] = self.model_patcher - for x in multigpu_patchers: - multigpu_dict[x.load_device] = x - self.model_options["multigpu_clones"] = multigpu_dict + multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) if denoise_mask is not None: denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index dec395fb3..b3c8635b8 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -28,7 +28,7 @@ class MultiGPUInitialize: model = model.clone() comfy.model_management.unload_all_models() for device in extra_devices: - device_patcher = model.multigpu_clone(new_load_device=device) + device_patcher = model.multigpu_deepclone(new_load_device=device) device_patcher.is_multigpu_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) From 02a4d0ad7de479c8e1145b2305edab4bac2a2e45 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 23 Jan 2025 01:20:00 -0600 Subject: [PATCH 24/52] Added unload_model_and_clones to model_management.py to allow unloading only relevant models --- comfy/model_management.py | 10 ++++++++++ comfy/model_patcher.py | 4 ++-- comfy/sampler_helpers.py | 1 + comfy_extras/nodes_multigpu.py | 2 +- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 87ad290d0..2cf792b56 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1146,6 +1146,16 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) +def unload_model_and_clones(model: ModelPatcher): + 'Unload only model and its clones - primarily for multigpu cloning purposes.' + initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + keep_loaded = [] + for loaded_model in initial_keep_loaded: + if loaded_model.model is not None: + if model.clone_base_uuid == loaded_model.model.clone_base_uuid: + continue + keep_loaded.append(loaded_model) + free_memory(1e30, get_torch_device(), keep_loaded) #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 46779397e..b4efa8d02 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -222,7 +222,7 @@ class ModelPatcher: self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.is_multigpu_clone = False - self.clone_uuid = uuid.uuid4() + self.clone_base_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -300,7 +300,7 @@ class ModelPatcher: n.hook_mode = self.hook_mode n.is_multigpu_clone = self.is_multigpu_clone - n.clone_uuid = self.clone_uuid + n.clone_base_uuid = self.clone_base_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index a95231ff5..5564b62c2 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -173,6 +173,7 @@ def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_mo multigpu_dict[model_patcher.load_device] = model_patcher for x in multigpu_patchers: x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + x.hook_mode = model_patcher.hook_mode # match main model's hook_mode multigpu_dict[x.load_device] = x model_options["multigpu_clones"] = multigpu_dict return multigpu_patchers diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index b3c8635b8..b5c36c64d 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -26,7 +26,7 @@ class MultiGPUInitialize: extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: model = model.clone() - comfy.model_management.unload_all_models() + comfy.model_management.unload_model_and_clones(model) for device in extra_devices: device_patcher = model.multigpu_deepclone(new_load_device=device) device_patcher.is_multigpu_clone = True From 5db42774496189142ef1521d7280e47b14044de5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 23 Jan 2025 19:06:05 -0600 Subject: [PATCH 25/52] Make sure additional_models are unloaded as well when perform --- comfy/model_management.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2cf792b56..c72ed247d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1146,14 +1146,25 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) -def unload_model_and_clones(model: ModelPatcher): +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True): 'Unload only model and its clones - primarily for multigpu cloning purposes.' initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + additional_models = [] + if unload_additional_models: + additional_models = model.get_nested_additional_models() keep_loaded = [] for loaded_model in initial_keep_loaded: if loaded_model.model is not None: if model.clone_base_uuid == loaded_model.model.clone_base_uuid: continue + # check additional models if they are a match + skip = False + for add_model in additional_models: + if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid: + skip = True + break + if skip: + continue keep_loaded.append(loaded_model) free_memory(1e30, get_torch_device(), keep_loaded) From 46969c380aa15dd7f26dfcee67dc59b9d830a3bd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 24 Jan 2025 05:39:38 -0600 Subject: [PATCH 26/52] Initial MultiGPU support for controlnets --- comfy/controlnet.py | 49 ++++++++++++++++++++++++++++++++++++++---- comfy/samplers.py | 52 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ee29251b9..0029a4987 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -15,13 +15,14 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - +from __future__ import annotations import torch from enum import Enum import math import os import logging +import copy import comfy.utils import comfy.model_management import comfy.model_detection @@ -36,7 +37,7 @@ import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet import comfy.cldm.dit_embedder -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from comfy.hooks import HookGroup @@ -76,7 +77,7 @@ class ControlBase: self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' self.extra_args = {} - self.previous_controlnet = None + self.previous_controlnet: Union[ControlBase, None] = None self.extra_conds = [] self.strength_type = StrengthType.CONSTANT self.concat_mask = False @@ -84,6 +85,7 @@ class ControlBase: self.extra_concat = None self.extra_hooks: HookGroup = None self.preprocess_image = lambda a: a + self.multigpu_clones: dict[torch.device, ControlBase] = {} def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): self.cond_hint_original = cond_hint @@ -117,10 +119,33 @@ class ControlBase: def get_models(self): out = [] + for device_cnet in self.multigpu_clones.values(): + out += device_cnet.get_models() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out + def get_models_only_self(self): + 'Calls get_models, but temporarily sets previous_controlnet to None.' + try: + orig_previous_controlnet = self.previous_controlnet + self.previous_controlnet = None + return self.get_models() + finally: + self.previous_controlnet = orig_previous_controlnet + + def get_instance_for_device(self, device): + 'Returns instance of this Control object intended for selected device.' + return self.multigpu_clones.get(device, self) + + def deepclone_multigpu(self, load_device, autoregister=False): + ''' + Create deep clone of Control object where model(s) is set to other devices. + + When autoregister is set to True, the deep clone is also added to multigpu_clones dict. + ''' + raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.") + def get_extra_hooks(self): out = [] if self.extra_hooks is not None: @@ -129,7 +154,7 @@ class ControlBase: out += self.previous_controlnet.get_extra_hooks() return out - def copy_to(self, c): + def copy_to(self, c: ControlBase): c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range @@ -280,6 +305,14 @@ class ControlNet(ControlBase): self.copy_to(c) return c + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.control_model = copy.deepcopy(c.control_model) + c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + if autoregister: + self.multigpu_clones[load_device] = c + return c + def get_models(self): out = super().get_models() out.append(self.control_model_wrapped) @@ -809,6 +842,14 @@ class T2IAdapter(ControlBase): c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c + + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.t2i_model = copy.deepcopy(c.t2i_model) + c.device = load_device + if autoregister: + self.multigpu_clones[load_device] = c + return c def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options compression_ratio = 8 diff --git a/comfy/samplers.py b/comfy/samplers.py index cf97b9820..27d875709 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,4 +1,6 @@ from __future__ import annotations + +import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc from typing import TYPE_CHECKING, Callable, NamedTuple @@ -427,7 +429,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t cond_or_uncond = [] uuids = [] area = [] - control = None + control: ControlBase = None patches = None for x in to_batch: o = x @@ -473,7 +475,8 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t c['transformer_options'] = transformer_options if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) if 'model_function_wrapper' in model_options: output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) @@ -799,6 +802,8 @@ def pre_run_control(model, conds): percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) + for device_cnet in x['control'].multigpu_clones.values(): + device_cnet.pre_run(model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -1080,6 +1085,48 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): wc_list[i] = wc_list[i].to(cast) +def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher): + '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' + multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) == 0: + return + extra_devices = [x.load_device for x in multigpu_models] + # handle controlnets + controlnets: set[ControlBase] = set() + for k in conds: + for kk in conds[k]: + if 'control' in kk: + controlnets.add(kk['control']) + if len(controlnets) > 0: + # first, unload all controlnet clones + for cnet in list(controlnets): + cnet_models = cnet.get_models() + for cm in cnet_models: + comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) + + # next, make sure each controlnet has a deepclone for all relevant devices + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + for device in extra_devices: + if device not in curr_cnet.multigpu_clones: + curr_cnet.deepclone_multigpu(device, autoregister=True) + curr_cnet = curr_cnet.previous_controlnet + # since all device clones are now present, recreate the linked list for cloned cnets per device + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + prev_cnet = curr_cnet.previous_controlnet + for device in extra_devices: + device_cnet = curr_cnet.get_instance_for_device(device) + prev_device_cnet = None + if prev_cnet is not None: + prev_device_cnet = prev_cnet.get_instance_for_device(device) + device_cnet.set_previous_controlnet(prev_device_cnet) + curr_cnet = prev_cnet + # TODO: handle gligen + + class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -1122,6 +1169,7 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device From 51af7fa1b4f42c674d755e60bfb9a67410f956b4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 25 Jan 2025 06:05:01 -0600 Subject: [PATCH 27/52] Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets --- comfy/controlnet.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0029a4987..31227ae31 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -64,6 +64,18 @@ class StrengthType(Enum): CONSTANT = 1 LINEAR_UP = 2 +class ControlIsolation: + '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.''' + def __init__(self, control: ControlBase): + self.control = control + self.orig_previous_controlnet = control.previous_controlnet + + def __enter__(self): + self.control.previous_controlnet = None + + def __exit__(self, *args): + self.control.previous_controlnet = self.orig_previous_controlnet + class ControlBase: def __init__(self): self.cond_hint_original = None @@ -112,7 +124,9 @@ class ControlBase: def cleanup(self): if self.previous_controlnet is not None: self.previous_controlnet.cleanup() - + for device_cnet in self.multigpu_clones.values(): + with ControlIsolation(device_cnet): + device_cnet.cleanup() self.cond_hint = None self.extra_concat = None self.timestep_range = None @@ -120,19 +134,15 @@ class ControlBase: def get_models(self): out = [] for device_cnet in self.multigpu_clones.values(): - out += device_cnet.get_models() + out += device_cnet.get_models_only_self() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out def get_models_only_self(self): 'Calls get_models, but temporarily sets previous_controlnet to None.' - try: - orig_previous_controlnet = self.previous_controlnet - self.previous_controlnet = None + with ControlIsolation(self): return self.get_models() - finally: - self.previous_controlnet = orig_previous_controlnet def get_instance_for_device(self, device): 'Returns instance of this Control object intended for selected device.' From c7feef90605801fbda28ae473c46008f7b5b404b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 26 Jan 2025 05:29:27 -0600 Subject: [PATCH 28/52] Cast transformer_options for multigpu --- comfy/samplers.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 27d875709..b8b30f2c6 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -471,7 +471,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t transformer_options["uuids"] = uuids[:] transformer_options["sigmas"] = timestep transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + cast_transformer_options(transformer_options, device=device) c['transformer_options'] = transformer_options if control is not None: @@ -1045,7 +1047,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return + cast_transformer_options(to_load_options, device, dtype) +def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None): casts = [] if device is not None: casts.append(device) @@ -1054,18 +1058,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - # try to call .to on patches - if "patches" in to_load_options: - patches = to_load_options["patches"] + if "patches" in transformer_options: + patches = transformer_options["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): for cast in casts: patch_list[i] = patch_list[i].to(cast) - if "patches_replace" in to_load_options: - patches = to_load_options["patches_replace"] + if "patches_replace" in transformer_options: + patches = transformer_options["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: @@ -1075,8 +1078,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: - if wc_name in to_load_options: - wc: dict[str, list] = to_load_options[wc_name] + if wc_name in transformer_options: + wc: dict[str, list] = transformer_options[wc_name] for wc_dict in wc.values(): for wc_list in wc_dict.values(): for i in range(len(wc_list)): From e3298b84de502a9df8a20ed2ab2877a30d631ff7 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 26 Jan 2025 09:34:20 -0600 Subject: [PATCH 29/52] Create proper MultiGPU Initialize node, create gpu_options to create scaffolding for asymmetrical GPU support --- comfy_extras/nodes_multigpu.py | 109 ++++++++++++++++++++++++++++++--- 1 file changed, 101 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index b5c36c64d..2ec1e3cfa 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,27 +1,32 @@ +from __future__ import annotations +import torch + from comfy.model_patcher import ModelPatcher import comfy.utils import comfy.patcher_extension import comfy.model_management -import copy class MultiGPUInitialize: + NodeId = "MultiGPU_Initialize" + NodeName = "MultiGPU Initialize" @classmethod def INPUT_TYPES(cls): return { "required": { "model": ("MODEL",), + "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), }, "optional": { - "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), + "gpu_options": ("GPU_OPTIONS",) } } RETURN_TYPES = ("MODEL",) FUNCTION = "init_multigpu" - CATEGORY = "DevTools" + CATEGORY = "advanced/multigpu" - def init_multigpu(self, model: ModelPatcher, max_gpus: int): + def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None): extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: @@ -33,9 +38,97 @@ class MultiGPUInitialize: multigpu_models = model.get_additional_models_with_key("multigpu") multigpu_models.append(device_patcher) model.set_additional_models("multigpu", multigpu_models) + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) return (model,) - -NODE_CLASS_MAPPINGS = { - "test_multigpuinit": MultiGPUInitialize, -} \ No newline at end of file +class MultiGPUOptionsNode: + NodeId = "MultiGPU_Options" + NodeName = "MultiGPU Options" + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "device_index": ("INT", {"default": 0, "min": 0, "max": 64}), + "relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01}) + }, + "optional": { + "gpu_options": ("GPU_OPTIONS",) + } + } + + RETURN_TYPES = ("GPU_OPTIONS",) + FUNCTION = "create_gpu_options" + CATEGORY = "advanced/multigpu" + + def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: GPUOptionsGroup=None): + if not gpu_options: + gpu_options = GPUOptionsGroup() + gpu_options.clone() + + opt = GPUOptions(device_index=device_index, relative_speed=relative_speed) + gpu_options.add(opt) + + return (gpu_options,) + + +class GPUOptions: + def __init__(self, device_index: int, relative_speed: float): + self.device_index = device_index + self.relative_speed = relative_speed + + def clone(self): + return GPUOptions(self.device_index, self.relative_speed) + + def create_dict(self): + return { + "relative_speed": self.relative_speed + } + +class GPUOptionsGroup: + def __init__(self): + self.options: dict[int, GPUOptions] = {} + + def add(self, info: GPUOptions): + self.options[info.device_index] = info + + def clone(self): + c = GPUOptionsGroup() + for opt in self.options.values(): + c.add(opt) + return c + + def register(self, model: ModelPatcher): + opts_dict = {} + # get devices that are valid for this model + devices: list[torch.device] = [model.load_device] + for extra_model in model.get_additional_models_with_key("multigpu"): + extra_model: ModelPatcher + devices.append(extra_model.load_device) + # create dictionary with actual device mapped to its GPUOptions + device_opts_list: list[GPUOptions] = [] + for device in devices: + device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) + opts_dict[device] = device_opts.create_dict() + device_opts_list.append(device_opts) + # make relative_speed relative to 1.0 + max_speed = max([x.relative_speed for x in device_opts_list]) + for value in opts_dict.values(): + value["relative_speed"] /= max_speed + model.model_options["multigpu_options"] = opts_dict + + +node_list = [ + MultiGPUInitialize, + MultiGPUOptionsNode +] +NODE_CLASS_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS = {} + +for node in node_list: + NODE_CLASS_MAPPINGS[node.NodeId] = node + NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName + +# TODO: remove +NODE_CLASS_MAPPINGS["test_multigpuinit"] = MultiGPUInitialize From eda866bf5113fcbbc03877bcfaa10bb4c24518f9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 27 Jan 2025 06:25:48 -0600 Subject: [PATCH 30/52] Extracted multigpu core code into multigpu.py, added load_balance_devices to get subdivision of work based on available devices and splittable work item count, added MultiGPU Options nodes to set relative_speed of specific devices; does not change behavior yet --- comfy/multigpu.py | 107 +++++++++++++++++++++++++++++++++ comfy_extras/nodes_multigpu.py | 58 ++---------------- 2 files changed, 113 insertions(+), 52 deletions(-) create mode 100644 comfy/multigpu.py diff --git a/comfy/multigpu.py b/comfy/multigpu.py new file mode 100644 index 000000000..2a1fc29d2 --- /dev/null +++ b/comfy/multigpu.py @@ -0,0 +1,107 @@ +from __future__ import annotations +import torch + +from collections import namedtuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + + +class GPUOptions: + def __init__(self, device_index: int, relative_speed: float): + self.device_index = device_index + self.relative_speed = relative_speed + + def clone(self): + return GPUOptions(self.device_index, self.relative_speed) + + def create_dict(self): + return { + "relative_speed": self.relative_speed + } + +class GPUOptionsGroup: + def __init__(self): + self.options: dict[int, GPUOptions] = {} + + def add(self, info: GPUOptions): + self.options[info.device_index] = info + + def clone(self): + c = GPUOptionsGroup() + for opt in self.options.values(): + c.add(opt) + return c + + def register(self, model: ModelPatcher): + opts_dict = {} + # get devices that are valid for this model + devices: list[torch.device] = [model.load_device] + for extra_model in model.get_additional_models_with_key("multigpu"): + extra_model: ModelPatcher + devices.append(extra_model.load_device) + # create dictionary with actual device mapped to its GPUOptions + device_opts_list: list[GPUOptions] = [] + for device in devices: + device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) + opts_dict[device] = device_opts.create_dict() + device_opts_list.append(device_opts) + # make relative_speed relative to 1.0 + min_speed = min([x.relative_speed for x in device_opts_list]) + for value in opts_dict.values(): + value['relative_speed'] /= min_speed + model.model_options['multigpu_options'] = opts_dict + + +LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) +def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): + 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' + opts_dict = model_options['multigpu_options'] + devices = list(model_options['multigpu_clones'].keys()) + speed_per_device = [] + work_per_device = [] + # get sum of each device's relative_speed + total_speed = 0.0 + for opts in opts_dict.values(): + total_speed += opts['relative_speed'] + # get relative work for each device; + # obtained by w = (W*r)/R + for device in devices: + relative_speed = opts_dict[device]['relative_speed'] + relative_work = (total_work*relative_speed) / total_speed + speed_per_device.append(relative_speed) + work_per_device.append(relative_work) + # relative work must be expressed in whole numbers, but likely is a decimal; + # perform rounding while maintaining total sum equal to total work (sum of relative works) + work_per_device = round_preserved(work_per_device) + dict_work_per_device = {} + for device, relative_work in zip(devices, work_per_device): + dict_work_per_device[device] = relative_work + if not return_idle_time: + return LoadBalance(dict_work_per_device, None) + # divide relative work by relative speed to get estimated completion time of said work by each device; + # time here is relative and does not correspond to real-world units + completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] + # calculate relative time spent by the devices waiting on each other after their work is completed + idle_time = abs(min(completion_time) - max(completion_time)) + if work_normalized: + idle_time *= (work_normalized/total_work) + + return LoadBalance(dict_work_per_device, idle_time) + +def round_preserved(values: list[float]): + 'Round all values in a list, preserving the combined sum of values.' + # get floor of values; casting to int does it too + floored = [int(x) for x in values] + total_floored = sum(floored) + # get remainder to distribute + remainder = round(sum(values)) - total_floored + # pair values with fractional portions + fractional = [(i, x-floored[i]) for i, x in enumerate(values)] + # sort by fractional part in descending order + fractional.sort(key=lambda x: x[1], reverse=True) + # distribute the remainder + for i in range(remainder): + index = fractional[i][0] + floored[index] += 1 + return floored diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 2ec1e3cfa..54f68182e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,10 +1,10 @@ from __future__ import annotations -import torch from comfy.model_patcher import ModelPatcher import comfy.utils import comfy.patcher_extension import comfy.model_management +import comfy.multigpu class MultiGPUInitialize: @@ -26,7 +26,7 @@ class MultiGPUInitialize: FUNCTION = "init_multigpu" CATEGORY = "advanced/multigpu" - def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None): + def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None): extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) extra_devices = extra_devices[:max_gpus-1] if len(extra_devices) > 0: @@ -39,7 +39,7 @@ class MultiGPUInitialize: multigpu_models.append(device_patcher) model.set_additional_models("multigpu", multigpu_models) if gpu_options is None: - gpu_options = GPUOptionsGroup() + gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options.register(model) return (model,) @@ -62,63 +62,17 @@ class MultiGPUOptionsNode: FUNCTION = "create_gpu_options" CATEGORY = "advanced/multigpu" - def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: GPUOptionsGroup=None): + def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None): if not gpu_options: - gpu_options = GPUOptionsGroup() + gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options.clone() - opt = GPUOptions(device_index=device_index, relative_speed=relative_speed) + opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) gpu_options.add(opt) return (gpu_options,) -class GPUOptions: - def __init__(self, device_index: int, relative_speed: float): - self.device_index = device_index - self.relative_speed = relative_speed - - def clone(self): - return GPUOptions(self.device_index, self.relative_speed) - - def create_dict(self): - return { - "relative_speed": self.relative_speed - } - -class GPUOptionsGroup: - def __init__(self): - self.options: dict[int, GPUOptions] = {} - - def add(self, info: GPUOptions): - self.options[info.device_index] = info - - def clone(self): - c = GPUOptionsGroup() - for opt in self.options.values(): - c.add(opt) - return c - - def register(self, model: ModelPatcher): - opts_dict = {} - # get devices that are valid for this model - devices: list[torch.device] = [model.load_device] - for extra_model in model.get_additional_models_with_key("multigpu"): - extra_model: ModelPatcher - devices.append(extra_model.load_device) - # create dictionary with actual device mapped to its GPUOptions - device_opts_list: list[GPUOptions] = [] - for device in devices: - device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) - opts_dict[device] = device_opts.create_dict() - device_opts_list.append(device_opts) - # make relative_speed relative to 1.0 - max_speed = max([x.relative_speed for x in device_opts_list]) - for value in opts_dict.values(): - value["relative_speed"] /= max_speed - model.model_options["multigpu_options"] = opts_dict - - node_list = [ MultiGPUInitialize, MultiGPUOptionsNode From 02747cde7ddacc3fd8a8165cf00aa13cbb770b12 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 29 Jan 2025 11:10:23 -0600 Subject: [PATCH 31/52] Carry over change from _calc_cond_batch into _calc_cond_batch_multigpu --- comfy/samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index b5252d144..f4873e3a5 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -357,7 +357,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t default_c.append(x) has_default_conds = True continue - p = comfy.samplers.get_area_and_mult(x, x_in, timestep) + p = get_area_and_mult(x, x_in, timestep) if p is None: continue if p.hooks is not None: From 476aa79b642f7b09c2a7bbe30a0763761eb11fe5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 6 Feb 2025 08:44:07 -0600 Subject: [PATCH 32/52] Let --cuda-device take in a string to allow multiple devices (or device order) to be chosen, print available devices on startup, potentially support MultiGPU Intel and Ascend setups --- comfy/cli_args.py | 2 +- comfy/model_management.py | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a92fc0dba..f54be19e4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -50,7 +50,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") +parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 420eb9e89..477bb0f5f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -141,6 +141,12 @@ def get_all_torch_devices(exclude_current=False): if is_nvidia(): for i in range(torch.cuda.device_count()): devices.append(torch.device(i)) + elif is_intel_xpu(): + for i in range(torch.xpu.device_count()): + devices.append(torch.device(i)) + elif is_ascend_npu(): + for i in range(torch.npu.device_count()): + devices.append(torch.device(i)) else: devices.append(get_torch_device()) if exclude_current: @@ -320,10 +326,14 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) + logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") - +try: + for device in get_all_torch_devices(exclude_current=True): + logging.info("Device [ ]: {}".format(get_torch_device_name(device))) +except: + pass current_loaded_models = [] From 093914a24714ef7264e34062fdeae46bd81964d9 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 3 Mar 2025 22:56:13 -0600 Subject: [PATCH 33/52] Made MultiGPU Work Units node more robust by forcing ModelPatcher clones to match at sample time, reuse loaded MultiGPU clones, finalize MultiGPU Work Units node ID and name, small refactors/cleanup of logging and multigpu-related code --- comfy/model_management.py | 14 ++++--- comfy/model_patcher.py | 67 +++++++++++++++++++++++++++++----- comfy/multigpu.py | 52 ++++++++++++++++++++++++++ comfy/patcher_extension.py | 2 + comfy/sampler_helpers.py | 47 ++++++++++++++++++++++-- comfy/samplers.py | 44 ---------------------- comfy_extras/nodes_multigpu.py | 49 ++++++++++++------------- 7 files changed, 188 insertions(+), 87 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dd762bdc5..3ee8857c2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -345,16 +345,16 @@ def get_torch_device_name(device): return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - logging.info("Device [X]: {}".format(get_torch_device_name(get_torch_device()))) + logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") try: for device in get_all_torch_devices(exclude_current=True): - logging.info("Device [ ]: {}".format(get_torch_device_name(device))) + logging.info("Device: {}".format(get_torch_device_name(device))) except: pass -current_loaded_models = [] +current_loaded_models: list[LoadedModel] = [] def module_size(module): module_mem = 0 @@ -1198,7 +1198,7 @@ def soft_empty_cache(force=False): def unload_all_models(): free_memory(1e30, get_torch_device()) -def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True): +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): 'Unload only model and its clones - primarily for multigpu cloning purposes.' initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() additional_models = [] @@ -1218,7 +1218,11 @@ def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True): if skip: continue keep_loaded.append(loaded_model) - free_memory(1e30, get_torch_device(), keep_loaded) + if not all_devices: + free_memory(1e30, get_torch_device(), keep_loaded) + else: + for device in get_all_torch_devices(): + free_memory(1e30, device, keep_loaded) #TODO: might be cleaner to put this somewhere else import threading diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index eb21396be..5ede41dd6 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -243,7 +243,7 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed - self.is_multigpu_clone = False + self.is_multigpu_base_clone = False self.clone_base_uuid = uuid.uuid4() if not hasattr(self.model, 'model_loaded_weight_memory'): @@ -324,14 +324,16 @@ class ModelPatcher: n.is_clip = self.is_clip n.hook_mode = self.hook_mode - n.is_multigpu_clone = self.is_multigpu_clone + n.is_multigpu_base_clone = self.is_multigpu_base_clone n.clone_base_uuid = self.clone_base_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n - def multigpu_deepclone(self, new_load_device=None, models_cache: dict[ModelPatcher,ModelPatcher]=None): + def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None): + logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.") + comfy.model_management.unload_model_and_clones(self) n = self.clone() # set load device, if present if new_load_device is not None: @@ -350,19 +352,64 @@ class ModelPatcher: for key, model_list in n.additional_models.items(): for i in range(len(model_list)): add_model = n.additional_models[key][i] - if i not in models_cache: - models_cache[add_model] = add_model.multigpu_deepclone(new_load_device=new_load_device, models_cache=models_cache) - n.additional_models[key][i] = models_cache[add_model] + if add_model.clone_base_uuid not in models_cache: + models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model.clone_base_uuid] + for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU): + callback(self, n) return n + def match_multigpu_clones(self): + multigpu_models = self.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + new_multigpu_models = [] + for mm in multigpu_models: + # clone main model, but bring over relevant props from existing multigpu clone + n = self.clone() + n.load_device = mm.load_device + n.backup = mm.backup + n.object_patches_backup = mm.object_patches_backup + n.hook_backup = mm.hook_backup + n.model = mm.model + n.is_multigpu_base_clone = mm.is_multigpu_base_clone + n.remove_additional_models("multigpu") + orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models) + n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models) + # figure out which additional models are not present in multigpu clone + models_cache = {} + for mm_add_model in mm.get_additional_models(): + models_cache[mm_add_model.clone_base_uuid] = mm_add_model + remove_models_uuids = set(list(models_cache.keys())) + for key, model_list in orig_additional_models.items(): + for orig_add_model in model_list: + if orig_add_model.clone_base_uuid not in models_cache: + models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache) + existing_list = n.get_additional_models_with_key(key) + existing_list.append(models_cache[orig_add_model.clone_base_uuid]) + n.set_additional_models(key, existing_list) + if orig_add_model.clone_base_uuid in remove_models_uuids: + remove_models_uuids.remove(orig_add_model.clone_base_uuid) + # remove duplicate additional models + for key, model_list in n.additional_models.items(): + new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids] + n.set_additional_models(key, new_model_list) + for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES): + callback(self, n) + new_multigpu_models.append(n) + self.set_additional_models("multigpu", new_multigpu_models) + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True return False - def clone_has_same_weights(self, clone: 'ModelPatcher'): - if not self.is_clone(clone): - return False + def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False): + if allow_multigpu: + if self.clone_base_uuid != clone.clone_base_uuid: + return False + else: + if not self.is_clone(clone): + return False if self.current_hooks != clone.current_hooks: return False @@ -957,7 +1004,7 @@ class ModelPatcher: return self.additional_models.get(key, []) def get_additional_models(self): - all_models = [] + all_models: list[ModelPatcher] = [] for models in self.additional_models.values(): all_models.extend(models) return all_models diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 2a1fc29d2..9cc8a37fa 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,10 +1,14 @@ from __future__ import annotations import torch +import logging from collections import namedtuple from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management class GPUOptions: @@ -53,6 +57,53 @@ class GPUOptionsGroup: model.model_options['multigpu_options'] = opts_dict +def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False): + 'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.' + model = model.clone() + # check if multigpu is already prepared - get the load devices from them if possible to exclude + skip_devices = set() + multigpu_models = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + for mm in multigpu_models: + skip_devices.add(mm.load_device) + skip_devices = list(skip_devices) + + extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + extra_devices = extra_devices[:max_gpus-1] + # exclude skipped devices + for skip in skip_devices: + if skip in extra_devices: + extra_devices.remove(skip) + # create new deepclones + if len(extra_devices) > 0: + for device in extra_devices: + device_patcher = None + if reuse_loaded: + # check if there are any ModelPatchers currently loaded that could be referenced here after a clone + loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() + for lm in loaded_models: + if lm.model is not None and lm.clone_base_uuid == model.clone_base_uuid and lm.load_device == device: + device_patcher = lm.clone() + logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") + break + if device_patcher is None: + device_patcher = model.deepclone_multigpu(new_load_device=device) + device_patcher.is_multigpu_base_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + model.match_multigpu_clones() + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) + else: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # persist skip_devices for use in sampling code + # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: + # model.model_options["multigpu_skip_devices"] = skip_devices + return model + + LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' @@ -84,6 +135,7 @@ def load_balance_devices(model_options: dict[str], total_work: int, return_idle_ completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] # calculate relative time spent by the devices waiting on each other after their work is completed idle_time = abs(min(completion_time) - max(completion_time)) + # if need to compare work idle time, need to normalize to a common total work if work_normalized: idle_time *= (work_normalized/total_work) diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 859758244..5145855f5 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -3,6 +3,8 @@ from typing import Callable class CallbacksMP: ON_CLONE = "on_clone" + ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu" + ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones" ON_LOAD = "on_load_after" ON_DETACH = "on_detach_after" ON_CLEANUP = "on_cleanup" diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 40b2021f7..9a97c8559 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -106,16 +106,57 @@ def cleanup_additional_models(models): if hasattr(m, 'cleanup'): m.cleanup() +def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]): + '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' + multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) == 0: + return + extra_devices = [x.load_device for x in multigpu_models] + # handle controlnets + controlnets: set[ControlBase] = set() + for k in conds: + for kk in conds[k]: + if 'control' in kk: + controlnets.add(kk['control']) + if len(controlnets) > 0: + # first, unload all controlnet clones + for cnet in list(controlnets): + cnet_models = cnet.get_models() + for cm in cnet_models: + comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) + + # next, make sure each controlnet has a deepclone for all relevant devices + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + for device in extra_devices: + if device not in curr_cnet.multigpu_clones: + curr_cnet.deepclone_multigpu(device, autoregister=True) + curr_cnet = curr_cnet.previous_controlnet + # since all device clones are now present, recreate the linked list for cloned cnets per device + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + prev_cnet = curr_cnet.previous_controlnet + for device in extra_devices: + device_cnet = curr_cnet.get_instance_for_device(device) + prev_device_cnet = None + if prev_cnet is not None: + prev_device_cnet = prev_cnet.get_instance_for_device(device) + device_cnet.set_previous_controlnet(prev_device_cnet) + curr_cnet = prev_cnet + # potentially handle gligen - since not widely used, ignored for now def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None): - real_model: BaseModel = None + model.match_multigpu_clones() + preprocess_multigpu_conds(conds, model, model_options) models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required) - real_model = model.model + real_model: BaseModel = model.model return real_model, conds, models @@ -166,7 +207,7 @@ def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_mo ''' In case multigpu acceleration is enabled, prep ModelPatchers for each device. ''' - multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_clone] + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone] if len(multigpu_patchers) > 0: multigpu_dict: dict[torch.device, ModelPatcher] = {} multigpu_dict[model_patcher.load_device] = model_patcher diff --git a/comfy/samplers.py b/comfy/samplers.py index beef0b7e4..d02627d8a 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1088,49 +1088,6 @@ def cast_transformer_options(transformer_options: dict[str], device=None, dtype= for cast in casts: wc_list[i] = wc_list[i].to(cast) - -def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model_options: dict[str], model: ModelPatcher): - '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' - multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") - if len(multigpu_models) == 0: - return - extra_devices = [x.load_device for x in multigpu_models] - # handle controlnets - controlnets: set[ControlBase] = set() - for k in conds: - for kk in conds[k]: - if 'control' in kk: - controlnets.add(kk['control']) - if len(controlnets) > 0: - # first, unload all controlnet clones - for cnet in list(controlnets): - cnet_models = cnet.get_models() - for cm in cnet_models: - comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) - - # next, make sure each controlnet has a deepclone for all relevant devices - for cnet in controlnets: - curr_cnet = cnet - while curr_cnet is not None: - for device in extra_devices: - if device not in curr_cnet.multigpu_clones: - curr_cnet.deepclone_multigpu(device, autoregister=True) - curr_cnet = curr_cnet.previous_controlnet - # since all device clones are now present, recreate the linked list for cloned cnets per device - for cnet in controlnets: - curr_cnet = cnet - while curr_cnet is not None: - prev_cnet = curr_cnet.previous_controlnet - for device in extra_devices: - device_cnet = curr_cnet.get_instance_for_device(device) - prev_device_cnet = None - if prev_cnet is not None: - prev_device_cnet = prev_cnet.get_instance_for_device(device) - device_cnet.set_previous_controlnet(prev_device_cnet) - curr_cnet = prev_cnet - # TODO: handle gligen - - class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -1173,7 +1130,6 @@ class CFGGuider: return self.inner_model.process_latent_out(samples.to(torch.float32)) def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): - preprocess_multigpu_conds(self.conds, self.model_options, self.model_patcher) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 54f68182e..d1e458b7e 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,15 +1,24 @@ from __future__ import annotations +import logging +from inspect import cleandoc -from comfy.model_patcher import ModelPatcher -import comfy.utils -import comfy.patcher_extension -import comfy.model_management +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher import comfy.multigpu -class MultiGPUInitialize: - NodeId = "MultiGPU_Initialize" - NodeName = "MultiGPU Initialize" +class MultiGPUWorkUnitsNode: + """ + Prepares model to have sampling accelerated via splitting work units. + + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. + """ + + NodeId = "MultiGPU_WorkUnits" + NodeName = "MultiGPU Work Units" @classmethod def INPUT_TYPES(cls): return { @@ -25,25 +34,17 @@ class MultiGPUInitialize: RETURN_TYPES = ("MODEL",) FUNCTION = "init_multigpu" CATEGORY = "advanced/multigpu" + DESCRIPTION = cleandoc(__doc__) def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None): - extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - extra_devices = extra_devices[:max_gpus-1] - if len(extra_devices) > 0: - model = model.clone() - comfy.model_management.unload_model_and_clones(model) - for device in extra_devices: - device_patcher = model.multigpu_deepclone(new_load_device=device) - device_patcher.is_multigpu_clone = True - multigpu_models = model.get_additional_models_with_key("multigpu") - multigpu_models.append(device_patcher) - model.set_additional_models("multigpu", multigpu_models) - if gpu_options is None: - gpu_options = comfy.multigpu.GPUOptionsGroup() - gpu_options.register(model) + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True) return (model,) class MultiGPUOptionsNode: + """ + Select the relative speed of GPUs in the special case they have significantly different performance from one another. + """ + NodeId = "MultiGPU_Options" NodeName = "MultiGPU Options" @classmethod @@ -61,6 +62,7 @@ class MultiGPUOptionsNode: RETURN_TYPES = ("GPU_OPTIONS",) FUNCTION = "create_gpu_options" CATEGORY = "advanced/multigpu" + DESCRIPTION = cleandoc(__doc__) def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None): if not gpu_options: @@ -74,7 +76,7 @@ class MultiGPUOptionsNode: node_list = [ - MultiGPUInitialize, + MultiGPUWorkUnitsNode, MultiGPUOptionsNode ] NODE_CLASS_MAPPINGS = {} @@ -83,6 +85,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {} for node in node_list: NODE_CLASS_MAPPINGS[node.NodeId] = node NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName - -# TODO: remove -NODE_CLASS_MAPPINGS["test_multigpuinit"] = MultiGPUInitialize From 6dca17bd2dd7455701d5eb466d39d72aa4520b1c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 3 Mar 2025 23:08:29 -0600 Subject: [PATCH 34/52] Satisfy ruff linting --- comfy/controlnet.py | 6 +++--- comfy/model_management.py | 1 - comfy/multigpu.py | 6 +++--- comfy/samplers.py | 4 ++-- comfy_extras/nodes_multigpu.py | 5 ++--- 5 files changed, 10 insertions(+), 12 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9bcd1d2e3..14f13bd9d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -72,7 +72,7 @@ class ControlIsolation: def __enter__(self): self.control.previous_controlnet = None - + def __exit__(self, *args): self.control.previous_controlnet = self.orig_previous_controlnet @@ -151,7 +151,7 @@ class ControlBase: def deepclone_multigpu(self, load_device, autoregister=False): ''' Create deep clone of Control object where model(s) is set to other devices. - + When autoregister is set to True, the deep clone is also added to multigpu_clones dict. ''' raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.") @@ -846,7 +846,7 @@ class T2IAdapter(ControlBase): c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm) self.copy_to(c) return c - + def deepclone_multigpu(self, load_device, autoregister=False): c = self.copy() c.t2i_model = copy.deepcopy(c.t2i_model) diff --git a/comfy/model_management.py b/comfy/model_management.py index 10d0dece2..6e243a437 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -30,7 +30,6 @@ import gc from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - from comfy.model_base import BaseModel class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 9cc8a37fa..aef0b68e8 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -18,7 +18,7 @@ class GPUOptions: def clone(self): return GPUOptions(self.device_index, self.relative_speed) - + def create_dict(self): return { "relative_speed": self.relative_speed @@ -86,7 +86,7 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: device_patcher = lm.clone() logging.info(f"Reusing loaded deepclone of {device_patcher.model.__class__.__name__} for {device}") break - if device_patcher is None: + if device_patcher is None: device_patcher = model.deepclone_multigpu(new_load_device=device) device_patcher.is_multigpu_base_clone = True multigpu_models = model.get_additional_models_with_key("multigpu") @@ -138,7 +138,7 @@ def load_balance_devices(model_options: dict[str], total_work: int, return_idle_ # if need to compare work idle time, need to normalize to a common total work if work_normalized: idle_time *= (work_normalized/total_work) - + return LoadBalance(dict_work_per_device, idle_time) def round_preserved(values: list[float]): diff --git a/comfy/samplers.py b/comfy/samplers.py index babfe7a45..bc97f9f71 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -384,7 +384,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()] device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} - + total_conds = 0 for to_run in hooked_to_run.values(): total_conds += len(to_run) @@ -504,7 +504,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) threads.append(new_thread) new_thread.start() - + for thread in threads: thread.join() diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index d1e458b7e..3b68c10ff 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,5 +1,4 @@ from __future__ import annotations -import logging from inspect import cleandoc from typing import TYPE_CHECKING @@ -11,7 +10,7 @@ import comfy.multigpu class MultiGPUWorkUnitsNode: """ Prepares model to have sampling accelerated via splitting work units. - + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. Other than those exceptions, this node can be placed in any order. @@ -30,7 +29,7 @@ class MultiGPUWorkUnitsNode: "gpu_options": ("GPU_OPTIONS",) } } - + RETURN_TYPES = ("MODEL",) FUNCTION = "init_multigpu" CATEGORY = "advanced/multigpu" From 9ce9ff8ef862f23a2486e97f4721fa56c3cea29a Mon Sep 17 00:00:00 2001 From: "kosinkadink1@gmail.com" Date: Fri, 28 Mar 2025 15:29:44 +0800 Subject: [PATCH 35/52] Allow chained MultiGPU Work Unit nodes to affect max_gpus present on ModelPatcher clone --- comfy/multigpu.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index aef0b68e8..26edcee90 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -68,8 +68,9 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: skip_devices.add(mm.load_device) skip_devices = list(skip_devices) - extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - extra_devices = extra_devices[:max_gpus-1] + full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) + limit_extra_devices = full_extra_devices[:max_gpus-1] + extra_devices = limit_extra_devices.copy() # exclude skipped devices for skip in skip_devices: if skip in extra_devices: @@ -98,6 +99,13 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # only keep model clones that don't go 'past' the intended max_gpu count + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [] + for m in multigpu_models: + if m.load_device in limit_extra_devices: + new_multigpu_models.append(m) + model.set_additional_models("multigpu", new_multigpu_models) # persist skip_devices for use in sampling code # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: # model.model_options["multigpu_skip_devices"] = skip_devices From 407a5a656f103c42497f5938a80d0771712b8613 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 28 Mar 2025 02:48:11 -0500 Subject: [PATCH 36/52] Rollback core of last commit due to weird behavior --- comfy/multigpu.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 26edcee90..90995a5ab 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -99,13 +99,13 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: gpu_options.register(model) else: logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") - # only keep model clones that don't go 'past' the intended max_gpu count - multigpu_models = model.get_additional_models_with_key("multigpu") - new_multigpu_models = [] - for m in multigpu_models: - if m.load_device in limit_extra_devices: - new_multigpu_models.append(m) - model.set_additional_models("multigpu", new_multigpu_models) + # TODO: only keep model clones that don't go 'past' the intended max_gpu count + # multigpu_models = model.get_additional_models_with_key("multigpu") + # new_multigpu_models = [] + # for m in multigpu_models: + # if m.load_device in limit_extra_devices: + # new_multigpu_models.append(m) + # model.set_additional_models("multigpu", new_multigpu_models) # persist skip_devices for use in sampling code # if len(skip_devices) > 0 or "multigpu_skip_devices" in model.model_options: # model.model_options["multigpu_skip_devices"] = skip_devices From 8be711715c471db68f9cea15989b5ec0f2ac2e7d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 19 Apr 2025 17:35:54 -0500 Subject: [PATCH 37/52] Make unload_all_models account for all devices --- comfy/model_management.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 90785c2c5..88c1c0a12 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1272,7 +1272,8 @@ def soft_empty_cache(force=False): torch.cuda.ipc_collect() def unload_all_models(): - free_memory(1e30, get_torch_device()) + for device in get_all_torch_devices(): + free_memory(1e30, device) def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): 'Unload only model and its clones - primarily for multigpu cloning purposes.' From 44e053c26dc8982e88973a253eef51b9a9a91302 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 24 Jun 2025 00:48:51 -0500 Subject: [PATCH 38/52] Improve error handling for multigpu threads --- comfy/samplers.py | 143 +++++++++++++++++++++++++--------------------- 1 file changed, 78 insertions(+), 65 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 260527661..90cce078d 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -3,7 +3,7 @@ from __future__ import annotations import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc -from typing import TYPE_CHECKING, Callable, NamedTuple +from typing import TYPE_CHECKING, Callable, NamedTuple, Any if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel @@ -428,74 +428,85 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t if batched_to_run_length >= conds_per_device: index_device += 1 - thread_result = collections.namedtuple('thread_result', ['output', 'mult', 'area', 'batch_chunks', 'cond_or_uncond']) + class thread_result(NamedTuple): + output: Any + mult: Any + area: Any + batch_chunks: int + cond_or_uncond: Any + error: Exception = None + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): - model_current: BaseModel = model_options["multigpu_clones"][device].model - # run every hooked_to_run separately - with torch.no_grad(): - for hooks, to_batch in batch_tuple: - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - uuids = [] - area = [] - control: ControlBase = None - patches = None - for x in to_batch: - o = x - p = o[0] - input_x.append(p.input_x) - mult.append(p.mult) - c.append(p.conditioning) - area.append(p.area) - cond_or_uncond.append(o[1]) - uuids.append(p.uuid) - control = p.control - patches = p.patches + try: + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control: ControlBase = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x).to(device) - c = cond_cat(c, device=device) - timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) - transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) - if 'transformer_options' in model_options: - transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, - model_options['transformer_options'], - copy_dict1=False) + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) - if patches is not None: - # TODO: replace with merge_nested_dicts function - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches + if patches is not None: + # TODO: replace with merge_nested_dicts function + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + transformer_options["patches"] = cur_patches + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + + cast_transformer_options(transformer_options, device=device) + c['transformer_options'] = transformer_options + + if control is not None: + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["uuids"] = uuids[:] - transformer_options["sigmas"] = timestep - transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) - transformer_options["multigpu_thread_device"] = device - - cast_transformer_options(transformer_options, device=device) - c['transformer_options'] = transformer_options - - if control is not None: - device_control = control.get_instance_for_device(device) - c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) - else: - output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) - results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + except Exception as e: + results.append(thread_result(None, None, None, None, None, error=e)) + raise results: list[thread_result] = [] @@ -508,7 +519,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t for thread in threads: thread.join() - for output, mult, area, batch_chunks, cond_or_uncond in results: + for output, mult, area, batch_chunks, cond_or_uncond, error in results: + if error is not None: + raise error for o in range(batch_chunks): cond_index = cond_or_uncond[o] a = area[o] From d89dd5f0b04c09b01926002244280d98590f02fe Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 13 Oct 2025 22:00:34 -0700 Subject: [PATCH 39/52] Satisfy ruff --- comfy/sampler_helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index c43bd3bac..9aa9fa28a 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -12,7 +12,6 @@ import comfy.patcher_extension from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher - from comfy.model_base import BaseModel from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): From 4661d1db5aa774f972bc270f2a1e5f8cf20ea978 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 15 Oct 2025 17:34:36 -0700 Subject: [PATCH 40/52] Bring patches changes from _calc_cond_batch into _calc_cond_batch_multigpu --- comfy/samplers.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index e0e0296f8..ed702304c 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -481,17 +481,10 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t copy_dict1=False) if patches is not None: - # TODO: replace with merge_nested_dicts function - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches + transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts( + transformer_options.get("patches", {}), + patches + ) transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["uuids"] = uuids[:] From f4b99bc62389af315013dda85f24f2bbd262b686 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 17 Feb 2026 04:55:00 -0800 Subject: [PATCH 41/52] Made multigpu deepclone load model from disk to avoid needing to deepclone actual model object, fixed issues with merge, turn off cuda backend as it causes device mismatch issue with rope (and potentially other ops), will investigate --- comfy/model_patcher.py | 11 ++++++++++- comfy/quant_ops.py | 2 +- comfy/samplers.py | 4 ++-- comfy/sd.py | 2 ++ 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d0110c7c6..aa7b862e7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -23,6 +23,7 @@ import inspect import logging import math import uuid +import copy from typing import Callable, Optional import torch @@ -274,6 +275,7 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None self.is_multigpu_base_clone = False self.clone_base_uuid = uuid.uuid4() @@ -368,6 +370,7 @@ class ModelPatcher: n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.cached_patcher_init = self.cached_patcher_init n.is_multigpu_base_clone = self.is_multigpu_base_clone n.clone_base_uuid = self.clone_base_uuid @@ -382,12 +385,18 @@ class ModelPatcher: # set load device, if present if new_load_device is not None: n.load_device = new_load_device + if self.cached_patcher_init is not None: + temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] + n.model = temp_model_patcher.model + else: + n.model = copy.deepcopy(n.model) # unlike for normal clone, backup dicts that shared same ref should not; # otherwise, patchers that have deep copies of base models will erroneously influence each other. n.backup = copy.deepcopy(n.backup) n.object_patches_backup = copy.deepcopy(n.object_patches_backup) n.hook_backup = copy.deepcopy(n.hook_backup) - n.model = copy.deepcopy(n.model) # multigpu clone should not have multigpu additional_models entry n.remove_additional_models("multigpu") # multigpu_clone all stored additional_models; make sure circular references are properly handled diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457b..d8addefd8 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -20,7 +20,7 @@ try: if cuda_version < (13,): ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - + ck.registry.disable("cuda") # multigpu will not work rn with comfy-kitchen on cuda backend ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") diff --git a/comfy/samplers.py b/comfy/samplers.py index 3f5a699d9..5dee49e7e 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -418,7 +418,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(current_device) + free_memory = comfy.model_management.get_free_memory(current_device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] @@ -487,7 +487,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t transformer_options["cond_or_uncond"] = cond_or_uncond[:] transformer_options["uuids"] = uuids[:] - transformer_options["sigmas"] = timestep + transformer_options["sigmas"] = timestep.to(device) transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) transformer_options["multigpu_thread_device"] = device diff --git a/comfy/sd.py b/comfy/sd.py index f65e7cadd..2643de26d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1510,6 +1510,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) + out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) return out def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): @@ -1711,6 +1712,7 @@ def load_diffusion_model(unet_path, model_options={}): if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) + model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model def load_unet(unet_path, dtype=None): From 84f465e791f4957921b1452fc239fa6794c96f22 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 07:07:54 -0700 Subject: [PATCH 42/52] Set CUDA device at start of multigpu threads to avoid multithreading bugs Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy/samplers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index ab691ed5b..1ff50f51d 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -444,6 +444,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: + torch.cuda.set_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately with torch.no_grad(): From d52dcbc88fa225707bc18269da69b7c18cbbf5b3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 07:23:13 -0700 Subject: [PATCH 43/52] Rewrite multigpu nodes to V3 format Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 104 +++++++++++++++++---------------- 1 file changed, 54 insertions(+), 50 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 3b68c10ff..789038b1d 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -1,13 +1,17 @@ from __future__ import annotations -from inspect import cleandoc +from inspect import cleandoc from typing import TYPE_CHECKING +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher import comfy.multigpu -class MultiGPUWorkUnitsNode: +class MultiGPUWorkUnitsNode(io.ComfyNode): """ Prepares model to have sampling accelerated via splitting work units. @@ -16,54 +20,53 @@ class MultiGPUWorkUnitsNode: Other than those exceptions, this node can be placed in any order. """ - NodeId = "MultiGPU_WorkUnits" - NodeName = "MultiGPU Work Units" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "model": ("MODEL",), - "max_gpus" : ("INT", {"default": 8, "min": 1, "step": 1}), - }, - "optional": { - "gpu_options": ("GPU_OPTIONS",) - } - } + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_WorkUnits", + display_name="MultiGPU Work Units", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Int.Input("max_gpus", default=8, min=1, step=1), + io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + ], + outputs=[ + io.Model.Output(), + ], + ) - RETURN_TYPES = ("MODEL",) - FUNCTION = "init_multigpu" - CATEGORY = "advanced/multigpu" - DESCRIPTION = cleandoc(__doc__) - - def init_multigpu(self, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup=None): + @classmethod + def execute(cls, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True) - return (model,) + return io.NodeOutput(model) -class MultiGPUOptionsNode: + +class MultiGPUOptionsNode(io.ComfyNode): """ Select the relative speed of GPUs in the special case they have significantly different performance from one another. """ - NodeId = "MultiGPU_Options" - NodeName = "MultiGPU Options" @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "device_index": ("INT", {"default": 0, "min": 0, "max": 64}), - "relative_speed": ("FLOAT", {"default": 1.0, "min": 0.0, "step": 0.01}) - }, - "optional": { - "gpu_options": ("GPU_OPTIONS",) - } - } + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_Options", + display_name="MultiGPU Options", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Int.Input("device_index", default=0, min=0, max=64), + io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01), + io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + ], + outputs=[ + io.Custom("GPU_OPTIONS").Output(), + ], + ) - RETURN_TYPES = ("GPU_OPTIONS",) - FUNCTION = "create_gpu_options" - CATEGORY = "advanced/multigpu" - DESCRIPTION = cleandoc(__doc__) - - def create_gpu_options(self, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup=None): + @classmethod + def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: if not gpu_options: gpu_options = comfy.multigpu.GPUOptionsGroup() gpu_options.clone() @@ -71,16 +74,17 @@ class MultiGPUOptionsNode: opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) gpu_options.add(opt) - return (gpu_options,) + return io.NodeOutput(gpu_options) -node_list = [ - MultiGPUWorkUnitsNode, - MultiGPUOptionsNode -] -NODE_CLASS_MAPPINGS = {} -NODE_DISPLAY_NAME_MAPPINGS = {} +class MultiGPUExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + MultiGPUWorkUnitsNode, + MultiGPUOptionsNode, + ] -for node in node_list: - NODE_CLASS_MAPPINGS[node.NodeId] = node - NODE_DISPLAY_NAME_MAPPINGS[node.NodeId] = node.NodeName + +async def comfy_entrypoint() -> MultiGPUExtension: + return MultiGPUExtension() From 5f4fcd19e7a5ce82b998495d18c10f4a111e41b7 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 07:30:32 -0700 Subject: [PATCH 44/52] Simplify multigpu nodes: default max_gpus=2, remove gpu_options input, disable Options node Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 789038b1d..c77dd5c1f 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -29,8 +29,7 @@ class MultiGPUWorkUnitsNode(io.ComfyNode): description=cleandoc(cls.__doc__), inputs=[ io.Model.Input("model"), - io.Int.Input("max_gpus", default=8, min=1, step=1), - io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), @@ -38,8 +37,8 @@ class MultiGPUWorkUnitsNode(io.ComfyNode): ) @classmethod - def execute(cls, model: ModelPatcher, max_gpus: int, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: - model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True) + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) return io.NodeOutput(model) @@ -82,7 +81,7 @@ class MultiGPUExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ MultiGPUWorkUnitsNode, - MultiGPUOptionsNode, + # MultiGPUOptionsNode, ] From 1d8e379f41154354edf7879d21606cd8dabd575a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:00:20 -0700 Subject: [PATCH 45/52] Rename MultiGPU Work Units to MultiGPU CFG Split Amp-Thread-ID: https://ampcode.com/threads/T-019d3ee9-19d5-767a-9d7a-e50cbbef815b Co-authored-by: Amp --- comfy_extras/nodes_multigpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index c77dd5c1f..5d24952bf 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: import comfy.multigpu -class MultiGPUWorkUnitsNode(io.ComfyNode): +class MultiGPUCFGSplitNode(io.ComfyNode): """ Prepares model to have sampling accelerated via splitting work units. @@ -24,7 +24,7 @@ class MultiGPUWorkUnitsNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MultiGPU_WorkUnits", - display_name="MultiGPU Work Units", + display_name="MultiGPU CFG Split", category="advanced/multigpu", description=cleandoc(cls.__doc__), inputs=[ @@ -80,7 +80,7 @@ class MultiGPUExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - MultiGPUWorkUnitsNode, + MultiGPUCFGSplitNode, # MultiGPUOptionsNode, ] From afdddcee66cb80b81bdc071da3773a54652d1284 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:32:52 -0700 Subject: [PATCH 46/52] Re-enable comfy-kitchen cuda backend for multigpu testing Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/quant_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 9375255d1..37e546722 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -20,7 +20,6 @@ try: if cuda_version < (13,): ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - ck.registry.disable("cuda") # multigpu will not work rn with comfy-kitchen on cuda backend ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") From 3fab720be9123a710578b94a89f94d80f9601761 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:45:55 -0700 Subject: [PATCH 47/52] Add debug logging for device mismatch in ModelPatcherDynamic.load Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/model_management.py | 2 ++ comfy/model_patcher.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index c89f7a246..3e58e7dd9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -639,6 +639,8 @@ class LoadedModel: return True def model_use_more_vram(self, extra_memory, force_patch_weights=False): + if self.device != self.model.load_device: + logging.error(f"LoadedModel device mismatch: self.device={self.device}, model.load_device={self.model.load_device}, model_class={self.model.model.__class__.__name__}, is_multigpu={getattr(self.model, 'is_multigpu_base_clone', False)}, id(model)={id(self.model)}") return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) def __eq__(self, other): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c3ecc276f..a3872926d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1646,6 +1646,8 @@ class ModelPatcherDynamic(ModelPatcher): #now. assert not full_load + if device_to != self.load_device: + logging.error(f"ModelPatcherDynamic.load device mismatch: device_to={device_to}, self.load_device={self.load_device}, model_class={self.model.__class__.__name__}, is_multigpu_base_clone={getattr(self, 'is_multigpu_base_clone', False)}, id(self)={id(self)}") assert device_to == self.load_device num_patches = 0 From 20803749c3be2666d2ee34f0371c6f483a792b5d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:53:36 -0700 Subject: [PATCH 48/52] Add detailed multigpu debug logging to load_models_gpu Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3e58e7dd9..76d475c0d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -780,16 +780,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model_index is not None: loaded = current_loaded_models[loaded_model_index] loaded.currently_used = True + logging.info(f"[MULTIGPU_DBG] Reusing LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded.device}, model.load_device={loaded.model.load_device}, is_multigpu={getattr(loaded.model, 'is_multigpu_base_clone', False)}, id(patcher)={id(loaded.model)}, id(inner)={id(loaded.model.model)}") models_to_load.append(loaded) else: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") + logging.info(f"[MULTIGPU_DBG] New LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded_model.device}, model.load_device={x.load_device}, is_multigpu={getattr(x, 'is_multigpu_base_clone', False)}, id(patcher)={id(x)}, id(inner)={id(x.model)}") models_to_load.append(loaded_model) for loaded_model in models_to_load: to_unload = [] for i in range(len(current_loaded_models)): if loaded_model.model.is_clone(current_loaded_models[i].model): + logging.info(f"[MULTIGPU_DBG] is_clone match: unloading idx={i}, LoadedModel.device={current_loaded_models[i].device}, model.load_device={current_loaded_models[i].model.load_device}, id(inner)={id(current_loaded_models[i].model.model)}") to_unload = [i] + to_unload for i in to_unload: model_to_unload = current_loaded_models.pop(i) From b418fb1582946578ca04daf0aeeee76955e79c7a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 30 Mar 2026 08:56:33 -0700 Subject: [PATCH 49/52] Fix device mismatch: update LoadedModel.device when _switch_parent swaps to parent patcher When a multigpu clone ModelPatcher is garbage collected, LoadedModel._switch_parent switches the weakref to point at the parent (main) ModelPatcher. However, it was not updating LoadedModel.device, leaving it with the old clone's device (e.g., cuda:1). On subsequent runs, this stale device was passed to ModelPatcherDynamic.load(), causing an assertion failure (device_to != self.load_device). Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba Co-authored-by: Amp --- comfy/model_management.py | 6 +----- comfy/model_patcher.py | 2 -- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 76d475c0d..14d9f80fb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -577,6 +577,7 @@ class LoadedModel: model = self._parent_model() if model is not None: self._set_model(model) + self.device = model.load_device @property def model(self): @@ -639,8 +640,6 @@ class LoadedModel: return True def model_use_more_vram(self, extra_memory, force_patch_weights=False): - if self.device != self.model.load_device: - logging.error(f"LoadedModel device mismatch: self.device={self.device}, model.load_device={self.model.load_device}, model_class={self.model.model.__class__.__name__}, is_multigpu={getattr(self.model, 'is_multigpu_base_clone', False)}, id(model)={id(self.model)}") return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) def __eq__(self, other): @@ -780,19 +779,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if loaded_model_index is not None: loaded = current_loaded_models[loaded_model_index] loaded.currently_used = True - logging.info(f"[MULTIGPU_DBG] Reusing LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded.device}, model.load_device={loaded.model.load_device}, is_multigpu={getattr(loaded.model, 'is_multigpu_base_clone', False)}, id(patcher)={id(loaded.model)}, id(inner)={id(loaded.model.model)}") models_to_load.append(loaded) else: if hasattr(x, "model"): logging.info(f"Requested to load {x.model.__class__.__name__}") - logging.info(f"[MULTIGPU_DBG] New LoadedModel for {x.model.__class__.__name__}: LoadedModel.device={loaded_model.device}, model.load_device={x.load_device}, is_multigpu={getattr(x, 'is_multigpu_base_clone', False)}, id(patcher)={id(x)}, id(inner)={id(x.model)}") models_to_load.append(loaded_model) for loaded_model in models_to_load: to_unload = [] for i in range(len(current_loaded_models)): if loaded_model.model.is_clone(current_loaded_models[i].model): - logging.info(f"[MULTIGPU_DBG] is_clone match: unloading idx={i}, LoadedModel.device={current_loaded_models[i].device}, model.load_device={current_loaded_models[i].model.load_device}, id(inner)={id(current_loaded_models[i].model.model)}") to_unload = [i] + to_unload for i in to_unload: model_to_unload = current_loaded_models.pop(i) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a3872926d..c3ecc276f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1646,8 +1646,6 @@ class ModelPatcherDynamic(ModelPatcher): #now. assert not full_load - if device_to != self.load_device: - logging.error(f"ModelPatcherDynamic.load device mismatch: device_to={device_to}, self.load_device={self.load_device}, model_class={self.model.__class__.__name__}, is_multigpu_base_clone={getattr(self, 'is_multigpu_base_clone', False)}, id(self)={id(self)}") assert device_to == self.load_device num_patches = 0 From 4b93c4360f4d09fa6f3a360fbf74c858c86f091a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 8 Apr 2026 02:39:07 -1000 Subject: [PATCH 50/52] Implement persistent thread pool for multi-GPU CFG splitting (#13329) Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device() once at startup, preserving compiled kernel caches across diffusion steps. - Add MultiGPUThreadPool class in comfy/multigpu.py - Create pool in CFGGuider.outer_sample(), shut down in finally block - Main thread handles its own device batch directly for zero overhead - Falls back to sequential execution if no pool is available --- comfy/multigpu.py | 63 ++++++++++++++++++++++++++++++++++++++++ comfy/sampler_helpers.py | 1 + comfy/samplers.py | 57 +++++++++++++++++++++++++++--------- 3 files changed, 108 insertions(+), 13 deletions(-) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 90995a5ab..096270c12 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,4 +1,6 @@ from __future__ import annotations +import queue +import threading import torch import logging @@ -11,6 +13,67 @@ import comfy.patcher_extension import comfy.model_management +class MultiGPUThreadPool: + """Persistent thread pool for multi-GPU work distribution. + + Maintains one worker thread per extra GPU device. Each thread calls + torch.cuda.set_device() once at startup so that compiled kernel caches + (inductor/triton) stay warm across diffusion steps. + """ + + def __init__(self, devices: list[torch.device]): + self._workers: list[threading.Thread] = [] + self._work_queues: dict[torch.device, queue.Queue] = {} + self._result_queues: dict[torch.device, queue.Queue] = {} + + for device in devices: + wq = queue.Queue() + rq = queue.Queue() + self._work_queues[device] = wq + self._result_queues[device] = rq + t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True) + t.start() + self._workers.append(t) + + def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): + try: + torch.cuda.set_device(device) + except Exception as e: + logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") + while True: + item = work_q.get() + if item is None: + return + result_q.put((None, e)) + return + while True: + item = work_q.get() + if item is None: + break + fn, args, kwargs = item + try: + result = fn(*args, **kwargs) + result_q.put((result, None)) + except Exception as e: + result_q.put((None, e)) + + def submit(self, device: torch.device, fn, *args, **kwargs): + self._work_queues[device].put((fn, args, kwargs)) + + def get_result(self, device: torch.device): + return self._result_queues[device].get() + + @property + def devices(self) -> list[torch.device]: + return list(self._work_queues.keys()) + + def shutdown(self): + for wq in self._work_queues.values(): + wq.put(None) # sentinel + for t in self._workers: + t.join(timeout=5.0) + + class GPUOptions: def __init__(self, device_index: int, relative_speed: float): self.device_index = device_index diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 844fadacd..6f5447d95 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -11,6 +11,7 @@ import comfy.hooks import comfy.patcher_extension from typing import TYPE_CHECKING if TYPE_CHECKING: + from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher from comfy.controlnet import ControlBase diff --git a/comfy/samplers.py b/comfy/samplers.py index 1ff50f51d..68f093749 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -18,10 +18,10 @@ import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows +import comfy.multigpu import comfy.utils import scipy.stats import numpy -import threading def add_area_dims(area, num_dims): @@ -509,15 +509,38 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t raise - results: list[thread_result] = [] - threads: list[threading.Thread] = [] - for device, batch_tuple in device_batched_hooked_to_run.items(): - new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results)) - threads.append(new_thread) - new_thread.start() + def _handle_batch_pooled(device, batch_tuple): + worker_results = [] + _handle_batch(device, batch_tuple, worker_results) + return worker_results - for thread in threads: - thread.join() + results: list[thread_result] = [] + thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") + main_device = output_device + main_batch_tuple = None + + # Submit extra GPU work to pool first, then run main device on this thread + pool_devices = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + if device == main_device and thread_pool is not None: + main_batch_tuple = batch_tuple + elif thread_pool is not None: + thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) + pool_devices.append(device) + else: + # Fallback: no pool, run everything on main thread + _handle_batch(device, batch_tuple, results) + + # Run main device batch on this thread (parallel with pool workers) + if main_batch_tuple is not None: + _handle_batch(main_device, main_batch_tuple, results) + + # Collect results from pool workers + for device in pool_devices: + worker_results, error = thread_pool.get_result(device) + if error is not None: + raise error + results.extend(worker_results) for output, mult, area, batch_chunks, cond_or_uncond, error in results: if error is not None: @@ -1187,17 +1210,25 @@ class CFGGuider: multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + # Create persistent thread pool for extra GPU devices + if multigpu_patchers: + extra_devices = [p.load_device for p in multigpu_patchers] + self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices) try: + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) + sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + self.model_patcher.pre_run() for multigpu_patcher in multigpu_patchers: multigpu_patcher.pre_run() output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: + thread_pool = self.model_options.pop("multigpu_thread_pool", None) + if thread_pool is not None: + thread_pool.shutdown() self.model_patcher.cleanup() for multigpu_patcher in multigpu_patchers: multigpu_patcher.cleanup() From 48deb15c0e2b3336de4ca27b3e920954dfde453b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 8 Apr 2026 22:15:57 -1000 Subject: [PATCH 51/52] Simplify multigpu dispatch: run all devices on pool threads (#13340) Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090 with SD1.5 and NetaYume models. No meaningful performance difference (within noise). All-pool is simpler: eliminates the main_device special case, main_batch_tuple deferred execution, and the 3-way branch in the dispatch loop. --- comfy/samplers.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 68f093749..8ebf1c496 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -516,25 +516,17 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t results: list[thread_result] = [] thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") - main_device = output_device - main_batch_tuple = None - # Submit extra GPU work to pool first, then run main device on this thread + # Submit all GPU work to pool threads pool_devices = [] for device, batch_tuple in device_batched_hooked_to_run.items(): - if device == main_device and thread_pool is not None: - main_batch_tuple = batch_tuple - elif thread_pool is not None: + if thread_pool is not None: thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) pool_devices.append(device) else: # Fallback: no pool, run everything on main thread _handle_batch(device, batch_tuple, results) - # Run main device batch on this thread (parallel with pool workers) - if main_batch_tuple is not None: - _handle_batch(main_device, main_batch_tuple, results) - # Collect results from pool workers for device in pool_devices: worker_results, error = thread_pool.get_result(device) @@ -1210,10 +1202,11 @@ class CFGGuider: multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) - # Create persistent thread pool for extra GPU devices + # Create persistent thread pool for all GPU devices (main + extras) if multigpu_patchers: extra_devices = [p.load_device for p in multigpu_patchers] - self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices) + all_devices = [device] + extra_devices + self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices) try: noise = noise.to(device=device, dtype=torch.float32) From f0d550bd02bc0f7550cad113eca852cdf5c805c6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 16 Apr 2026 15:49:01 +1000 Subject: [PATCH 52/52] Minor updates for worksplit_gpu with comfy-aimdo (#13419) * main: init all visible cuda devices in aimdo * mp: call vbars_analyze for the GPU in question * requirements: bump aimdo to pre-release version --- comfy/model_patcher.py | 3 ++- main.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c3ecc276f..a74a51902 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -319,7 +319,8 @@ class ModelPatcher: #than pays for CFG. So return everything both torch and Aimdo could give us aimdo_mem = 0 if comfy.memory_management.aimdo_enabled: - aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device) return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): diff --git a/main.py b/main.py index 12b04719d..de145a1e9 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,7 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") - +import torch import comfy.utils import execution @@ -210,7 +210,7 @@ import comfy.model_patcher if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + elif comfy_aimdo.control.init_devices(range(torch.cuda.device_count())): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': diff --git a/requirements.txt b/requirements.txt index 1a8e1ea1c..c60219a88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo>=0.2.12 +comfy-aimdo==0.0.213 requests simpleeval>=1.0.0 blake3