ComfyUI/comfy/sampler_helpers.py
Rattus 9d9e0c5631 make control-net load order deterministic
Make this deterministic so speeds don't change based on load order. Load
them in reverse order so whatever the caller lists first is the top
priority.
2026-05-04 23:49:33 +10:00

204 lines
8.3 KiB
Python

from __future__ import annotations
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
    """Reshape ``noise_mask`` to ``shape`` and move the result to ``device``.

    The reshape itself is delegated to ``comfy.utils.reshape_mask``; the value
    returned here is whatever ``.to(device)`` yields on that result.
    """
    reshaped = comfy.utils.reshape_mask(noise_mask, shape)
    return reshaped.to(device)
def get_models_from_cond(cond, model_type):
    """Collect the values stored under ``model_type`` across all cond entries.

    Entries that do not contain the key are skipped. A list value contributes
    each of its items; any other value is added as a single item. The result
    is always a new list, in the order the entries were visited.
    """
    found = []
    for entry in cond:
        if model_type not in entry:
            continue
        value = entry[model_type]
        if isinstance(value, list):
            found.extend(value)
        else:
            found.append(value)
    return found
def get_hooks_from_cond(cond, full_hooks: comfy.hooks.HookGroup):
    """Add every hook referenced by ``cond`` to ``full_hooks`` and return it.

    Two sources are collected:

    * hooks stored directly under the ``'hooks'`` key of a cond entry, and
    * ``extra_hooks`` carried by the entry's ``'control'`` object and by every
      object reachable from it through ``previous_controlnet``.

    Args:
        cond: Iterable of cond entries supporting ``in`` and key lookup.
        full_hooks: Hook group that is mutated in place through ``add``.

    Returns:
        The same ``full_hooks`` object that was passed in.
    """
    # get hooks from conds, and collect cnets so they can be checked for extra_hooks
    cnets: list[ControlBase] = []
    for c in cond:
        if 'hooks' in c:
            for hook in c['hooks'].hooks:
                full_hooks.add(hook)
        if 'control' in c:
            cnets.append(c['control'])

    hooks_list = []
    # Order-preserving dedup. A plain set() would make the order in which the
    # extra hooks are gathered (and therefore combined) vary across runs.
    for base_cnet in dict.fromkeys(cnets):
        # Walk the previous_controlnet chain iteratively so a long chain
        # cannot hit the interpreter's recursion limit.
        cnet = base_cnet
        while cnet is not None:
            if cnet.extra_hooks is not None:
                hooks_list.append(cnet.extra_hooks)
            cnet = cnet.previous_controlnet

    extra_hooks = comfy.hooks.HookGroup.combine_all_hooks(hooks_list)
    if extra_hooks is not None:
        for hook in extra_hooks.hooks:
            full_hooks.add(hook)
    return full_hooks
def convert_cond(cond):
    """Turn ``(cross_attn, options)`` style entries into flat cond dicts.

    For each entry a shallow copy of the dict at index 1 is made, then:

    * ``"cross_attn"`` is set to the item at index 0 when it is not ``None``;
    * ``"model_conds"`` is guaranteed to exist (defaulting to an empty dict);
    * ``"uuid"`` receives a freshly generated ``uuid.uuid4()``.

    The input dicts are not modified at the top level; a list of the new
    dicts is returned.
    """
    converted = []
    for entry in cond:
        new_cond = entry[1].copy()
        cross_attn = entry[0]
        if cross_attn is not None:
            new_cond["cross_attn"] = cross_attn
        new_cond["model_conds"] = new_cond.get("model_conds", {})
        new_cond["uuid"] = uuid.uuid4()
        converted.append(new_cond)
    return converted
def cond_has_hooks(cond):
    """Report whether any cond entry carries hooks.

    The dict at index 1 of each entry is inspected. An entry counts when it
    has a ``"hooks"`` key, or when it has a ``"control"`` object whose
    ``get_extra_hooks()`` returns a non-empty collection. Returns ``True`` at
    the first such entry and ``False`` when none is found.
    """
    for entry in cond:
        options = entry[1]
        if "hooks" in options:
            return True
        if "control" not in options:
            continue
        if len(options["control"].get_extra_hooks()) > 0:
            return True
    return False
def get_additional_models(conds, dtype):
    """Gather the additional models referenced by the conditioning.

    Every cond list in ``conds`` is scanned for ``"control"``, ``"gligen"``
    and ``"additional_models"`` values. Control objects are de-duplicated and
    each contributes its ``get_models()`` result plus its
    ``inference_memory_requirements(dtype)``.

    Returns:
        ``(models, inference_memory)`` where ``models`` is the control models,
        followed by item 1 of every gligen value, followed by the
        additional models, and ``inference_memory`` is the summed control
        requirement.
    """
    controls: list[ControlBase] = []
    gligen_entries = []
    extra_models = []
    for cond_list in conds.values():
        controls += get_models_from_cond(cond_list, "control")
        gligen_entries += get_models_from_cond(cond_list, "gligen")
        extra_models += get_models_from_cond(cond_list, "additional_models")

    # Dedup while keeping first-seen order; set() iteration order is not
    # stable across runs.
    unique_controls = list(dict.fromkeys(controls))

    control_models = []
    inference_memory = 0
    for control in unique_controls:
        control_models.extend(control.get_models())
        inference_memory += control.inference_memory_requirements(dtype)

    gligen_models = [item[1] for item in gligen_entries]
    return control_models + gligen_models + extra_models, inference_memory
def get_additional_models_from_model_options(model_options: dict[str]=None):
    """Collect models exposed by registered AdditionalModels hooks.

    Returns an empty list when ``model_options`` is ``None`` or has no
    ``"registered_hooks"`` entry. Otherwise the ``models`` of every hook of
    type ``EnumHookType.AdditionalModels`` in that group are concatenated.
    """
    collected = []
    if model_options is None or "registered_hooks" not in model_options:
        return collected
    hook_group: comfy.hooks.HookGroup = model_options["registered_hooks"]
    for hook in hook_group.get_type(comfy.hooks.EnumHookType.AdditionalModels):
        hook: comfy.hooks.AdditionalModelsHook
        collected.extend(hook.models)
    return collected
def cleanup_additional_models(models):
    """Call ``cleanup()`` on each model that provides one.

    Objects without a ``cleanup`` attribute are skipped silently.
    """
    # Lazy filter: each object is checked right before its cleanup runs.
    cleanable = (candidate for candidate in models if hasattr(candidate, 'cleanup'))
    for candidate in cleanable:
        candidate.cleanup()
def estimate_memory(model, noise_shape, conds):
    """Ask the model for its memory needs at doubled and at single batch size.

    ``model.model.extra_conds_shapes(**cond)`` is queried for every cond. All
    returned shapes are accumulated per key for the first estimate; for the
    second, only the shape with the largest element count (``math.prod``) is
    kept per key.

    Returns:
        ``(memory_required, minimum_memory_required)``: the results of
        ``model.model.memory_required`` for a batch dimension of
        ``noise_shape[0] * 2`` with all shapes, and of ``noise_shape[0]``
        with the per-key largest shape.
    """
    all_shapes = collections.defaultdict(list)
    largest_shape = {}
    for cond_list in conds.values():
        for cond in cond_list:
            for name, shape in model.model.extra_conds_shapes(**cond).items():
                all_shapes[name].append(shape)
                current = largest_shape.get(name)
                if current is None or math.prod(shape) > math.prod(current[0]):
                    largest_shape[name] = [shape]

    batch = noise_shape[0]
    doubled_shape = [batch * 2] + list(noise_shape[1:])
    single_shape = [batch] + list(noise_shape[1:])
    memory_required = model.model.memory_required(doubled_shape, cond_shapes=all_shapes)
    minimum_memory_required = model.model.memory_required(single_shape, cond_shapes=largest_shape)
    return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
    """Run ``_prepare_sampling`` through the PREPARE_SAMPLING wrapper chain.

    The wrappers are looked up from ``model_options`` and handed to a
    ``WrapperExecutor`` together with ``_prepare_sampling``; all arguments are
    forwarded unchanged and the executor's result is returned.
    """
    wrappers = comfy.patcher_extension.get_all_wrappers(
        comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING,
        model_options,
        is_model_options=True,
    )
    executor = comfy.patcher_extension.WrapperExecutor.new_executor(_prepare_sampling, wrappers)
    return executor.execute(
        model,
        noise_shape,
        conds,
        model_options=model_options,
        force_full_load=force_full_load,
        force_offload=force_offload,
    )
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
    """Collect all models needed for sampling and load them onto the GPU.

    Additional models come from the conds, from hooks registered in
    ``model_options`` and from the model patcher's nested additional models.
    They are loaded together with ``model`` via
    ``comfy.model_management.load_models_gpu``.

    Returns:
        ``(real_model, conds, models)`` where ``real_model`` is
        ``model.model`` read after loading, ``conds`` is the argument passed
        in, and ``models`` is the list of additional models.
    """
    models, inference_memory = get_additional_models(conds, model.model_dtype())
    models += get_additional_models_from_model_options(model_options)
    # TODO: does this require inference_memory update?
    models += model.get_nested_additional_models()

    if not force_offload:
        memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
        memory_required += inference_memory
        minimum_memory_required += inference_memory
    else:
        # In training + offload enabled, we want to force prepare sampling to trigger partial load
        memory_required = 1e20
        minimum_memory_required = None

    comfy.model_management.load_models_gpu(
        [model] + models,
        memory_required=memory_required,
        minimum_memory_required=minimum_memory_required,
        force_full_load=force_full_load,
    )
    real_model: BaseModel = model.model
    return real_model, conds, models
def cleanup_models(conds, models):
    """Clean up the additional models and every control object found in ``conds``.

    Args:
        conds: Mapping whose values are cond lists; each is scanned for
            ``"control"`` values.
        models: Additional models to clean up first.
    """
    cleanup_additional_models(models)

    control_cleanup = []
    for k in conds:
        control_cleanup += get_models_from_cond(conds[k], "control")

    # Order-preserving dedup: each control object is still cleaned up exactly
    # once, but in first-seen order rather than the run-dependent order a
    # plain set() would give.
    cleanup_additional_models(list(dict.fromkeys(control_cleanup)))
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
    '''
    Registers hooks from conds.

    Hooks found in the conds are applied to ``model`` / ``model_options``,
    the registered ones are stored under ``model_options["registered_hooks"]``
    (only when there is at least one), and the wrappers and callbacks present
    in ``model_options["transformer_options"]`` are merged into
    ``model_options["to_load_options"]``, which is returned.
    '''
    # gather every hook referenced by the conds
    cond_hooks = comfy.hooks.HookGroup()
    for cond_list in conds.values():
        get_hooks_from_cond(cond_list, cond_hooks)

    # bring the ModelPatcher's own wrappers and callbacks into transformer_options
    comfy.patcher_extension.merge_nested_dicts(
        model_options["transformer_options"].setdefault("wrappers", {}),
        model.wrappers,
        copy_dict1=False,
    )
    comfy.patcher_extension.merge_nested_dicts(
        model_options["transformer_options"].setdefault("callbacks", {}),
        model.callbacks,
        copy_dict1=False,
    )

    # hook registration starts here
    registered = comfy.hooks.HookGroup()
    target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model)

    # TransformerOptions hooks first, then AdditionalModels hooks
    patch_first = (
        comfy.hooks.EnumHookType.TransformerOptions,
        comfy.hooks.EnumHookType.AdditionalModels,
    )
    for hook_type in patch_first:
        for hook in cond_hooks.get_type(hook_type):
            hook.add_hook_patches(model, model_options, target_dict, registered)

    # WeightHooks are registered on the ModelPatcher itself
    model.register_all_hook_patches(cond_hooks, target_dict, model_options, registered)

    # keep the registered hooks on model_options for later lookups
    if len(registered) > 0:
        model_options["registered_hooks"] = registered

    # fold the (now hooked) wrappers and callbacks into to_load_options
    load_options: dict[str] = model_options.setdefault("to_load_options", {})
    for section in ("wrappers", "callbacks"):
        destination = load_options.setdefault(section, {})
        source = model_options["transformer_options"][section]
        comfy.patcher_extension.merge_nested_dicts(destination, source, copy_dict1=False)
    return load_options