Compare commits

...

10 Commits

Author SHA1 Message Date
rattus
b177cc4225
Merge 86e74e7f8b into 5943fbf457 2026-01-09 01:29:12 +09:00
Rattus
86e74e7f8b nodes: add cache barriers to models / clip 2025-12-19 22:35:25 +10:00
Rattus
783da446c1 comfy_execution: add cache barriers
Add a system where an input is marked as a cache barrier, deferring its
evaluation. Once the node comes up for execution, the barrier is released and
everything behind the barrier is executed at increased priority.
2025-12-19 22:35:25 +10:00
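As a rough illustration (not part of this diff), a custom node could opt an input into the barrier with the same dict-style flag used in the KSampler/CLIPTextEncode hunks below; the node name and inputs here are hypothetical:

# Hypothetical node sketch: "cache-barrier": True defers evaluation of the
# subgraph feeding this input until the node itself is staged, at which point
# that subgraph runs at increased priority.
class ExampleSamplerNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {"cache-barrier": True}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "run"

    def run(self, model, steps):
        ...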
Rattus
96ad4904fe mm: fix debug message 2025-12-19 19:32:51 +10:00
Rattus
4bb34b85b7 mm: make model offloading deferred with weakrefs
RAMPressure caching may need to purge the same model that you are
currently trying to offload for VRAM freeing. In this case, the RAMPressure
cache takes priority and needs to be able to pull the trigger on dumping
the whole model and freeing the ModelPatcher in question. To do this,
defer the actual transfer of model weights from GPU to RAM to
model_management state rather than doing it as part of ModelPatcher. This
is done as a list of weakrefs.

If the RAM cache decides to free the model you are currently unloading, then
the ModelPatcher and refs simply disappear in the middle of the
unloading process, and both RAM and VRAM will be freed.

The unpatcher now queues the individual leaf modules to be offloaded
one-by-one so that RAM levels can be monitored.

Note that the UnloadPartially that is potentially done as part of a
load will not be freeable this way; however, it shouldn't need to be, as
that is the currently active model and the RAM cache cannot save you if
you can't even fit the one model you are currently trying to use.
2025-12-19 19:32:51 +10:00
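A minimal sketch of the weakref-based deferral described above, using only standard weakref/torch behaviour; queue_offload/offload_queued are illustrative names, the diff's own version is offload_modules() further down:

import weakref

import torch

def queue_offload(modules):
    # Hold weak references only, so the RAM cache can still free the owner.
    return [weakref.ref(m) for m in modules]

def offload_queued(module_refs, offload_device):
    for ref in module_refs:
        module = ref()
        if module is None:
            continue  # the owning ModelPatcher was freed mid-offload; skip it
        module.to(offload_device)  # move weights one leaf module at a time

# Toy usage: queue the leaf modules of a small model, then offload them.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
refs = queue_offload(list(model.children()))
offload_queued(refs, torch.device("cpu"))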
Rattus
2c86040cf7 mm: don't use a list of indexes for the unload work list
This is currently put together as a list of indexes, assuming
current_loaded_models doesn't change. However, we might need to purge a
model as part of the offload process, which means this list can change in
the middle of the freeing process. Handle this by taking independent refs
to the LoadedModel objects and doing safe by-value deletion from
current_loaded_models.
2025-12-19 19:32:51 +10:00
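A toy sketch of that by-value pattern, with plain strings standing in for LoadedModel objects:

current_loaded_models = ["vae", "clip", "unet"]  # stand-ins for LoadedModel objects

# Take independent references up front instead of indexes...
worklist = list(current_loaded_models)

for entry in worklist:
    # ...so an entry that was already purged by someone else is simply skipped.
    if entry in current_loaded_models:
        current_loaded_models.remove(entry)  # safe by-value deletion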
Rattus
abe39647ee mm: make garbage collector null safe on real_model
Currently this hard-assumes that the caller of model_unload will keep
current_loaded_models in sync. With RAMPressureCache it's possible for
garbage collection to occur in the middle of the model free process,
which can split these two steps.
2025-12-19 19:32:51 +10:00
Rattus
3f4ee9174c sd: Free RAM on main model load 2025-12-19 19:32:51 +10:00
Rattus
f190744f62 mm: Add free_ram()
Add the free_ram() API and a means to install implementations of the
freeing mechanism (i.e. the RAM cache).
2025-12-19 19:32:51 +10:00
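A self-contained sketch of how such a listener API fits together, mirroring the free_ram/register_ram_listener hunk below; PrintingListener is a made-up listener for illustration:

import torch

ram_listeners = set()

class PrintingListener:
    def free_ram(self, extra_ram):
        print(f"asked to free roughly {extra_ram / 1e6:.1f} MB")

def register_ram_listener(listener):
    ram_listeners.add(listener)

def free_ram(extra_ram=0, state_dict=None):
    # Add the size of any state dict about to be materialized to the request.
    for tensor in (state_dict or {}).values():
        if isinstance(tensor, torch.Tensor):
            extra_ram += tensor.numel() * tensor.element_size()
    for listener in ram_listeners:
        listener.free_ram(extra_ram)

register_ram_listener(PrintingListener())
free_ram(state_dict={"w": torch.zeros(1024, 1024)})  # roughly 4 MB of float32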
Rattus
4612aab281 caching: build headroom into the RAM cache
Move the headroom logic into the RAM cache to make it a little easier
to call as "free me some RAM".

Rename the API to free_ram().

Split off the clean_list creation into a completely separate function to
avoid any stray strong reference to the content-to-be-freed on the
stack.
2025-12-19 19:32:51 +10:00
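A compact sketch of the resulting headroom logic, assuming min_headroom is in GiB and extra_ram in bytes; build_clean_list()/evict() stand in for the cache's real _build_clean_list() and del self.cache[key], and the real loop additionally applies a hysteresis multiplier to the target:

import gc

import psutil

def available_gb():
    return psutil.virtual_memory().available / (1024 ** 3)

def free_ram_sketch(cache, min_headroom_gb=4.0, extra_ram_bytes=0):
    target = min_headroom_gb + extra_ram_bytes / (1024 ** 3)
    if available_gb() > target:
        return                 # already enough headroom
    gc.collect()               # try the cheap route first
    if available_gb() > target:
        return
    # Build the eviction list in a separate helper so no strong reference to
    # the soon-to-be-freed outputs survives on this stack frame.
    clean_list = cache.build_clean_list()
    while available_gb() < target and clean_list:
        cache.evict(clean_list.pop())
        gc.collect()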
9 changed files with 156 additions and 59 deletions

View File

@@ -306,6 +306,7 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
comfy.model_management.free_ram(state_dict=to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

View File

@@ -448,6 +448,20 @@ try:
except:
logging.warning("Could not pick default device.")
current_ram_listeners = set()
def register_ram_listener(listener):
current_ram_listeners.add(listener)
def unregister_ram_listener(listener):
current_ram_listeners.discard(listener)
def free_ram(extra_ram=0, state_dict={}):
for tensor in state_dict.values():
if isinstance(tensor, torch.Tensor):
extra_ram += tensor.numel() * tensor.element_size()
for listener in current_ram_listeners:
listener.free_ram(extra_ram)
current_loaded_models = []
@@ -524,12 +538,18 @@ class LoadedModel:
return False
def model_unload(self, memory_to_free=None, unpatch_weights=True):
if self.model is None:
return True
logging.debug(f"Unloading {self.model.model.__class__.__name__}")
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
offload_modules(modules_to_offload, self.model.offload_device)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
if self.model is not None:
modules_to_offload = self.model.detach(unpatch_weights)
offload_modules(modules_to_offload, self.model.offload_device)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@@ -546,7 +566,7 @@ class LoadedModel:
self._patcher_finalizer.detach()
def is_dead(self):
return self.real_model() is not None and self.model is None
return self.real_model is not None and self.real_model() is not None and self.model is None
def use_more_memory(extra_memory, loaded_models, device):
@@ -581,6 +601,13 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def offload_modules(modules, offload_device):
for module in modules:
if module() is None:
continue
module().to(offload_device)
free_ram()
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
unloaded_model = []
@@ -591,23 +618,25 @@ def free_memory(memory_required, device, keep_loaded=[]):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model))
shift_model.currently_used = False
for x in sorted(can_unload):
i = x[-1]
shift_model = x[-1]
i = x[-2]
memory_to_free = None
if not DISABLE_SMART_MEMORY:
free_mem = get_free_memory(device)
if free_mem > memory_required:
break
memory_to_free = memory_required - free_mem
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
if shift_model.model_unload(memory_to_free):
unloaded_model.append((i, shift_model))
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
for i, shift_model in sorted(unloaded_model, reverse=True):
unloaded_models.append(shift_model)
if shift_model in current_loaded_models:
current_loaded_models.remove(shift_model)
if len(unloaded_model) > 0:
soft_empty_cache()
@@ -742,7 +771,7 @@ def cleanup_models_gc():
def cleanup_models():
to_delete = []
for i in range(len(current_loaded_models)):
if current_loaded_models[i].real_model() is None:
if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None:
to_delete = [i] + to_delete
for i in to_delete:

View File

@@ -24,6 +24,7 @@ import inspect
import logging
import math
import uuid
import weakref
from typing import Callable, Optional
import torch
@@ -832,6 +833,7 @@ class ModelPatcher:
def unpatch_model(self, device_to=None, unpatch_weights=True):
self.eject_model()
modules_to_move = []
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
@@ -856,7 +858,8 @@
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
modules_to_move.append(weakref.ref(self.model))
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@@ -870,12 +873,14 @@
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup.clear()
return modules_to_move
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
modules_to_move = []
unload_list = self._load_list()
unload_list.sort()
@@ -916,7 +921,7 @@
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
modules_to_move.append(weakref.ref(m))
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@@ -954,20 +959,22 @@
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
return memory_freed, modules_to_move
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
with self.use_ejected(skip_and_inject_on_exit_only=True):
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
# TODO: force_patch_weights should not unload + reload full model
used = self.model.model_loaded_weight_memory
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
if unpatch_weights:
extra_memory += (used - self.model.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
_, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
return 0
full_load = False
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@@ -979,7 +986,7 @@
try:
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
except Exception as e:
self.detach()
comfy.model_management.offload_modules(self.detach(), self.offload_device())
raise e
return self.model.model_loaded_weight_memory - current_used
@@ -987,11 +994,12 @@
def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
modules_to_offload = []
if unpatch_all:
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model
return modules_to_offload
def current_loaded_device(self):
return self.model.device

View File

@@ -288,6 +288,7 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
comfy.model_management.free_ram(state_dict=sd)
return self.cond_stage_model.load_state_dict(sd, strict=False)
else:
return self.cond_stage_model.load_sd(sd)
@@ -665,6 +666,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
comfy.model_management.free_ram(state_dict=sd)
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
if len(m) > 0:
logging.warning("Missing VAE keys {}".format(m))
@@ -986,6 +988,7 @@ def load_style_model(ckpt_path):
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
comfy.model_management.free_ram(state_dict=model_data)
model.load_state_dict(model_data)
return StyleModel(model)

View File

@@ -193,7 +193,7 @@ class BasicCache:
self._clean_cache()
self._clean_subcaches()
def poll(self, **kwargs):
def free_ram(self, *args, **kwargs):
pass
def _set_immediate(self, node_id, value):
@@ -284,7 +284,7 @@ class NullCache:
def clean_unused(self):
pass
def poll(self, **kwargs):
def free_ram(self, *args, **kwargs):
pass
def get(self, node_id):
@@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
def __init__(self, key_class, min_headroom=4.0):
super().__init__(key_class, 0)
self.timestamps = {}
self.min_headroom = min_headroom
def clean_unused(self):
self._clean_subcaches()
@@ -381,19 +382,10 @@
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
def _build_clean_list(self):
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, (_, outputs), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -416,8 +408,22 @@
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
return clean_list
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
def free_ram(self, extra_ram=0):
headroom_target = self.min_headroom + (extra_ram / (1024**3))
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > headroom_target:
return
gc.collect()
if _ram_gb() > headroom_target:
return
clean_list = self._build_clean_list()
while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()

View File

@@ -112,6 +112,8 @@ class TopologicalSort:
self.blocking = {} # Which nodes are blocked by this node
self.externalBlocks = 0
self.unblockedEvent = asyncio.Event()
self.priorities = {}
self.barrierNodes = set()
def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -130,13 +132,37 @@
def add_strong_link(self, from_node_id, from_socket, to_node_id):
if not self.is_cached(from_node_id):
self.add_node(from_node_id)
self.add_node(from_node_id, priority=self.priorities.get(to_node_id, 0))
if to_node_id not in self.blocking[from_node_id]:
self.blocking[from_node_id][to_node_id] = {}
self.blockCount[to_node_id] += 1
self.blocking[from_node_id][to_node_id][from_socket] = True
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
def is_barrier(self, node_id):
return node_id in self.barrierNodes
def unbarrier(self, node_id):
if not node_id in self.barrierNodes:
return
self.barrierNodes.remove(node_id)
self.priorities[node_id] = self.priorities.get(node_id, 0) + 1
links = []
inputs = self.dynprompt.get_node(node_id)["inputs"]
for input_name in inputs:
value = inputs[input_name]
if is_link(value):
from_node_id, from_socket = value
_, _, input_info = self.get_input_info(node_id, input_name)
is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"]
if is_barrier:
links.append((from_node_id, from_socket, node_id))
for link in links:
self.add_strong_link(*link)
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None, priority=0):
node_ids = [node_unique_id]
links = []
@@ -148,6 +174,7 @@ class TopologicalSort:
self.pendingNodes[unique_id] = True
self.blockCount[unique_id] = 0
self.blocking[unique_id] = {}
self.priorities[unique_id] = priority
inputs = self.dynprompt.get_node(unique_id)["inputs"]
for input_name in inputs:
@@ -158,10 +185,13 @@
continue
_, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy):
is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"]
if (include_lazy or (not is_lazy and not is_barrier)):
if not self.is_cached(from_node_id):
node_ids.append(from_node_id)
links.append((from_node_id, from_socket, unique_id))
if is_barrier:
self.barrierNodes.add(unique_id)
for link in links:
self.add_strong_link(*link)
@@ -180,7 +210,7 @@
return False
def get_ready_nodes(self):
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
return [(self.priorities.get(node_id, 0), node_id) for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
def pop_node(self, unique_id):
del self.pendingNodes[unique_id]
@@ -286,25 +316,34 @@ class ExecutionList(TopologicalSort):
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
for node_id in node_list:
priority_level = 0
priority_nodes = []
for (priority, node_id) in node_list:
if priority > priority_level:
priority_level = priority
priority_nodes = []
if priority == priority_level:
priority_nodes.append(node_id)
for node_id in priority_nodes:
if is_output(node_id) or is_async(node_id):
return node_id
#This should handle the VAEDecode -> preview case
for node_id in node_list:
for node_id in priority_nodes:
for blocked_node_id in self.blocking[node_id]:
if is_output(blocked_node_id):
return node_id
#This should handle the VAELoader -> VAEDecode -> preview case
for node_id in node_list:
for node_id in priority_nodes:
for blocked_node_id in self.blocking[node_id]:
for blocked_node_id1 in self.blocking[blocked_node_id]:
if is_output(blocked_node_id1):
return node_id
#TODO: this function should be improved
return node_list[0]
return priority_nodes[0]
def unstage_node_execution(self):
assert self.staged_node_id is not None

View File

@@ -19,7 +19,7 @@ class BasicScheduler(io.ComfyNode):
node_id="BasicScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
@@ -138,7 +138,7 @@ class SDTurboScheduler(io.ComfyNode):
node_id="SDTurboScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Int.Input("steps", default=1, min=1, max=10),
io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
],
@@ -162,7 +162,7 @@ class BetaSamplingScheduler(io.ComfyNode):
node_id="BetaSamplingScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
@@ -352,7 +352,7 @@ class SamplingPercentToSigma(io.ComfyNode):
node_id="SamplingPercentToSigma",
category="sampling/custom_sampling/sigmas",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
],
@@ -623,7 +623,7 @@ class SamplerSASolver(io.ComfyNode):
node_id="SamplerSASolver",
category="sampling/custom_sampling/samplers",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001),
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001),
@@ -719,7 +719,7 @@ class SamplerCustom(io.ComfyNode):
node_id="SamplerCustom",
category="sampling/custom_sampling",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Boolean.Input("add_noise", default=True),
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
@@ -784,7 +784,7 @@ class BasicGuider(io.ComfyNode):
node_id="BasicGuider",
category="sampling/custom_sampling/guiders",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Conditioning.Input("conditioning"),
],
outputs=[io.Guider.Output()]
@@ -805,7 +805,7 @@ class CFGGuider(io.ComfyNode):
node_id="CFGGuider",
category="sampling/custom_sampling/guiders",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
@@ -858,7 +858,7 @@ class DualCFGGuider(io.ComfyNode):
node_id="DualCFGGuider",
category="sampling/custom_sampling/guiders",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Conditioning.Input("cond1"),
io.Conditioning.Input("cond2"),
io.Conditioning.Input("negative"),
@@ -973,7 +973,7 @@ class AddNoise(io.ComfyNode):
category="_for_testing/custom_sampling/noise",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Model.Input("model", extra_dict={"cache-barrier":True}),
io.Noise.Input("noise"),
io.Sigmas.Input("sigmas"),
io.Latent.Input("latent_image"),

View File

@@ -108,7 +108,7 @@ class CacheSet:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
cache_ram = cache_args.get("ram", 4.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
@@ -130,7 +130,7 @@
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
@@ -427,7 +427,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
input_data_all = None
try:
if unique_id in pending_async_nodes:
if execution_list.is_barrier(unique_id):
execution_list.unbarrier(unique_id)
return (ExecutionResult.PENDING, None, None)
elif unique_id in pending_async_nodes:
results = []
for r in pending_async_nodes[unique_id]:
if isinstance(r, asyncio.Task):
@@ -622,13 +625,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class PromptExecutor:
def __init__(self, server, cache_type=False, cache_args=None):
self.caches = None
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
if self.caches is not None:
for cache in self.caches.all:
comfy.model_management.unregister_ram_listener(cache)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
for cache in self.caches.all:
comfy.model_management.register_ram_listener(cache)
self.status_messages = []
self.success = True
@@ -728,7 +739,7 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
self.caches.outputs.free_ram()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)

View File

@@ -60,7 +60,7 @@ class CLIPTextEncode(ComfyNodeABC):
return {
"required": {
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text.", "cache-barrier" : True})
}
}
RETURN_TYPES = (IO.CONDITIONING,)
@@ -1518,7 +1518,7 @@ class KSampler:
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent.", "cache-barrier": True}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
@@ -1545,7 +1545,7 @@ class KSamplerAdvanced:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
{"model": ("MODEL", {"cache-barrier": True}),
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),