diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..1e8ff51af 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -306,6 +306,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) + comfy.model_management.free_ram(state_dict=to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/model_management.py b/comfy/model_management.py index 928282092..bc8e9b662 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -448,6 +448,20 @@ try: except: logging.warning("Could not pick default device.") +current_ram_listeners = set() + +def register_ram_listener(listener): + current_ram_listeners.add(listener) + +def unregister_ram_listener(listener): + current_ram_listeners.discard(listener) + +def free_ram(extra_ram=0, state_dict={}): + for tensor in state_dict.values(): + if isinstance(tensor, torch.Tensor): + extra_ram += tensor.numel() * tensor.element_size() + for listener in current_ram_listeners: + listener.free_ram(extra_ram) current_loaded_models = [] @@ -524,12 +538,18 @@ class LoadedModel: return False def model_unload(self, memory_to_free=None, unpatch_weights=True): + if self.model is None: + return True + logging.debug(f"Unloading {self.model.model.__class__.__name__}") if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free) + offload_modules(modules_to_offload, self.model.offload_device) if freed >= memory_to_free: return False - self.model.detach(unpatch_weights) + if self.model is not None: + modules_to_offload = self.model.detach(unpatch_weights) + offload_modules(modules_to_offload, self.model.offload_device) self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -546,7 +566,7 @@ class LoadedModel: self._patcher_finalizer.detach() def is_dead(self): - return self.real_model() is not None and self.model is None + return self.real_model is not None and self.real_model() is not None and self.model is None def use_more_memory(extra_memory, loaded_models, device): @@ -581,6 +601,13 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() +def offload_modules(modules, offload_device): + for module in modules: + if module() is None: + continue + module().to(offload_device) + free_ram() + def free_memory(memory_required, device, keep_loaded=[]): cleanup_models_gc() unloaded_model = [] @@ -591,23 +618,25 @@ def free_memory(memory_required, device, keep_loaded=[]): shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded and not shift_model.is_dead(): - can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) + can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model)) shift_model.currently_used = False for x in sorted(can_unload): - i = x[-1] + shift_model = x[-1] + i = x[-2] memory_to_free = None if not DISABLE_SMART_MEMORY: free_mem = get_free_memory(device) if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.debug(f"Unloading 
{current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): - unloaded_model.append(i) + if shift_model.model_unload(memory_to_free): + unloaded_model.append((i, shift_model)) - for i in sorted(unloaded_model, reverse=True): - unloaded_models.append(current_loaded_models.pop(i)) + for i, shift_model in sorted(unloaded_model, reverse=True): + unloaded_models.append(shift_model) + if shift_model in current_loaded_models: + current_loaded_models.remove(shift_model) if len(unloaded_model) > 0: soft_empty_cache() @@ -742,7 +771,7 @@ def cleanup_models_gc(): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - if current_loaded_models[i].real_model() is None: + if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None: to_delete = [i] + to_delete for i in to_delete: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..cba2f9bf0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import inspect import logging import math import uuid +import weakref from typing import Callable, Optional import torch @@ -832,6 +833,7 @@ class ModelPatcher: def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() + modules_to_move = [] if unpatch_weights: self.unpatch_hooks() self.unpin_all_weights() @@ -856,7 +858,8 @@ class ModelPatcher: self.backup.clear() if device_to is not None: - self.model.to(device_to) + modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ] + modules_to_move.append(weakref.ref(self.model)) self.model.device = device_to self.model.model_loaded_weight_memory = 0 self.model.model_offload_buffer_memory = 0 @@ -870,12 +873,14 @@ class ModelPatcher: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) self.object_patches_backup.clear() + return modules_to_move def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 patch_counter = 0 + modules_to_move = [] unload_list = self._load_list() unload_list.sort() @@ -916,7 +921,7 @@ class ModelPatcher: bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - m.to(device_to) + modules_to_move.append(weakref.ref(m)) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: @@ -954,20 +959,22 @@ class ModelPatcher: self.model.model_loaded_weight_memory -= memory_freed self.model.model_offload_buffer_memory = offload_buffer logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter)) - return memory_freed + return memory_freed, modules_to_move def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): with self.use_ejected(skip_and_inject_on_exit_only=True): unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights) # TODO: force_patch_weights should not unload + reload full model used = self.model.model_loaded_weight_memory - self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights) + 
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
             if unpatch_weights:
                 extra_memory += (used - self.model.model_loaded_weight_memory)
             self.patch_model(load_weights=False)
             if extra_memory < 0 and not unpatch_weights:
-                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                _, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
                 return 0
             full_load = False
             if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
@@ -979,7 +986,7 @@ class ModelPatcher:
             try:
                 self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
             except Exception as e:
-                self.detach()
+                comfy.model_management.offload_modules(self.detach(), self.offload_device)
                 raise e

             return self.model.model_loaded_weight_memory - current_used

@@ -987,11 +994,12 @@ class ModelPatcher:
     def detach(self, unpatch_all=True):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        modules_to_offload = []
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
-        return self.model
+        return modules_to_offload

     def current_loaded_device(self):
         return self.model.device
diff --git a/comfy/sd.py b/comfy/sd.py
index 5a7221620..f359e242f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -288,6 +288,7 @@ class CLIP:

     def load_sd(self, sd, full_model=False):
         if full_model:
+            comfy.model_management.free_ram(state_dict=sd)
             return self.cond_stage_model.load_state_dict(sd, strict=False)
         else:
             return self.cond_stage_model.load_sd(sd)
@@ -665,6 +666,7 @@ class VAE:
                 self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()

+        comfy.model_management.free_ram(state_dict=sd)
         m, u = self.first_stage_model.load_state_dict(sd, strict=False)
         if len(m) > 0:
             logging.warning("Missing VAE keys {}".format(m))
@@ -986,6 +988,7 @@ def load_style_model(ckpt_path):
         model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
+    comfy.model_management.free_ram(state_dict=model_data)
     model.load_state_dict(model_data)
     return StyleModel(model)
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 326a279fc..43f882469 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -193,7 +193,7 @@ class BasicCache:
         self._clean_cache()
         self._clean_subcaches()

-    def poll(self, **kwargs):
+    def free_ram(self, *args, **kwargs):
         pass

     def _set_immediate(self, node_id, value):
@@ -284,7 +284,7 @@ class NullCache:
     def clean_unused(self):
         pass

-    def poll(self, **kwargs):
+    def free_ram(self, *args, **kwargs):
         pass

     def get(self, node_id):
@@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3


 class RAMPressureCache(LRUCache):
-    def __init__(self, key_class):
+    def __init__(self, key_class, min_headroom=4.0):
         super().__init__(key_class, 0)
         self.timestamps = {}
+        self.min_headroom = min_headroom

     def clean_unused(self):
         self._clean_subcaches()
@@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
         return super().get(node_id)

-    def poll(self, 
ram_headroom): - def _ram_gb(): - return psutil.virtual_memory().available / (1024**3) - - if _ram_gb() > ram_headroom: - return - gc.collect() - if _ram_gb() > ram_headroom: - return - + def _build_clean_list(self): clean_list = [] - for key, (outputs, _), in self.cache.items(): + for key, (_, outputs), in self.cache.items(): oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE @@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache): #In the case where we have no information on the node ram usage at all, #break OOM score ties on the last touch timestamp (pure LRU) bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + return clean_list - while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + def free_ram(self, extra_ram=0): + headroom_target = self.min_headroom + (extra_ram / (1024**3)) + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > headroom_target: + return + gc.collect() + if _ram_gb() > headroom_target: + return + + clean_list = self._build_clean_list() + + while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list: _, _, key = clean_list.pop() del self.cache[key] gc.collect() diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 9d170b16e..caf89ba5f 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -112,6 +112,8 @@ class TopologicalSort: self.blocking = {} # Which nodes are blocked by this node self.externalBlocks = 0 self.unblockedEvent = asyncio.Event() + self.priorities = {} + self.barrierNodes = set() def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] @@ -130,13 +132,37 @@ class TopologicalSort: def add_strong_link(self, from_node_id, from_socket, to_node_id): if not self.is_cached(from_node_id): - self.add_node(from_node_id) + self.add_node(from_node_id, priority=self.priorities.get(to_node_id, 0)) if to_node_id not in self.blocking[from_node_id]: self.blocking[from_node_id][to_node_id] = {} self.blockCount[to_node_id] += 1 self.blocking[from_node_id][to_node_id][from_socket] = True - def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): + def is_barrier(self, node_id): + return node_id in self.barrierNodes + + def unbarrier(self, node_id): + if not node_id in self.barrierNodes: + return + self.barrierNodes.remove(node_id) + self.priorities[node_id] = self.priorities.get(node_id, 0) + 1 + + links = [] + inputs = self.dynprompt.get_node(node_id)["inputs"] + + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + _, _, input_info = self.get_input_info(node_id, input_name) + is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"] + if is_barrier: + links.append((from_node_id, from_socket, node_id)) + + for link in links: + self.add_strong_link(*link) + + def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None, priority=0): node_ids = [node_unique_id] links = [] @@ -148,6 +174,7 @@ class TopologicalSort: self.pendingNodes[unique_id] = True self.blockCount[unique_id] = 0 self.blocking[unique_id] = {} + self.priorities[unique_id] = priority inputs = self.dynprompt.get_node(unique_id)["inputs"] for input_name in inputs: @@ -158,10 +185,13 @@ class TopologicalSort: continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" 
in input_info and input_info["lazy"] - if (include_lazy or not is_lazy): + is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"] + if (include_lazy or (not is_lazy and not is_barrier)): if not self.is_cached(from_node_id): node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) + if is_barrier: + self.barrierNodes.add(unique_id) for link in links: self.add_strong_link(*link) @@ -180,7 +210,7 @@ class TopologicalSort: return False def get_ready_nodes(self): - return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + return [(self.priorities.get(node_id, 0), node_id) for node_id in self.pendingNodes if self.blockCount[node_id] == 0] def pop_node(self, unique_id): del self.pendingNodes[unique_id] @@ -286,25 +316,34 @@ class ExecutionList(TopologicalSort): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION)) - for node_id in node_list: + priority_level = 0 + priority_nodes = [] + for (priority, node_id) in node_list: + if priority > priority_level: + priority_level = priority + priority_nodes = [] + if priority == priority_level: + priority_nodes.append(node_id) + + for node_id in priority_nodes: if is_output(node_id) or is_async(node_id): return node_id #This should handle the VAEDecode -> preview case - for node_id in node_list: + for node_id in priority_nodes: for blocked_node_id in self.blocking[node_id]: if is_output(blocked_node_id): return node_id #This should handle the VAELoader -> VAEDecode -> preview case - for node_id in node_list: + for node_id in priority_nodes: for blocked_node_id in self.blocking[node_id]: for blocked_node_id1 in self.blocking[blocked_node_id]: if is_output(blocked_node_id1): return node_id #TODO: this function should be improved - return node_list[0] + return priority_nodes[0] def unstage_node_execution(self): assert self.staged_node_id is not None diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f19adf4b9..987e2637b 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -19,7 +19,7 @@ class BasicScheduler(io.ComfyNode): node_id="BasicScheduler", category="sampling/custom_sampling/schedulers", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), @@ -138,7 +138,7 @@ class SDTurboScheduler(io.ComfyNode): node_id="SDTurboScheduler", category="sampling/custom_sampling/schedulers", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Int.Input("steps", default=1, min=1, max=10), io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), ], @@ -162,7 +162,7 @@ class BetaSamplingScheduler(io.ComfyNode): node_id="BetaSamplingScheduler", category="sampling/custom_sampling/schedulers", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), @@ -352,7 +352,7 @@ class SamplingPercentToSigma(io.ComfyNode): node_id="SamplingPercentToSigma", 
category="sampling/custom_sampling/sigmas", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), ], @@ -623,7 +623,7 @@ class SamplerSASolver(io.ComfyNode): node_id="SamplerSASolver", category="sampling/custom_sampling/samplers", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), @@ -719,7 +719,7 @@ class SamplerCustom(io.ComfyNode): node_id="SamplerCustom", category="sampling/custom_sampling", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Boolean.Input("add_noise", default=True), io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), @@ -784,7 +784,7 @@ class BasicGuider(io.ComfyNode): node_id="BasicGuider", category="sampling/custom_sampling/guiders", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Conditioning.Input("conditioning"), ], outputs=[io.Guider.Output()] @@ -805,7 +805,7 @@ class CFGGuider(io.ComfyNode): node_id="CFGGuider", category="sampling/custom_sampling/guiders", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), @@ -858,7 +858,7 @@ class DualCFGGuider(io.ComfyNode): node_id="DualCFGGuider", category="sampling/custom_sampling/guiders", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Conditioning.Input("cond1"), io.Conditioning.Input("cond2"), io.Conditioning.Input("negative"), @@ -973,7 +973,7 @@ class AddNoise(io.ComfyNode): category="_for_testing/custom_sampling/noise", is_experimental=True, inputs=[ - io.Model.Input("model"), + io.Model.Input("model", extra_dict={"cache-barrier":True}), io.Noise.Input("noise"), io.Sigmas.Input("sigmas"), io.Latent.Input("latent_image"), diff --git a/execution.py b/execution.py index 648f204ec..b42980620 100644 --- a/execution.py +++ b/execution.py @@ -108,7 +108,7 @@ class CacheSet: self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.RAM_PRESSURE: - cache_ram = cache_args.get("ram", 16.0) + cache_ram = cache_args.get("ram", 4.0) self.init_ram_cache(cache_ram) logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: @@ -130,7 +130,7 @@ class CacheSet: self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -427,7 +427,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, input_data_all = None try: - if unique_id in pending_async_nodes: + if 
unique_id in pending_async_nodes:
+        if execution_list.is_barrier(unique_id):
+            execution_list.unbarrier(unique_id)
+            return (ExecutionResult.PENDING, None, None)
+        elif unique_id in pending_async_nodes:
             results = []
             for r in pending_async_nodes[unique_id]:
                 if isinstance(r, asyncio.Task):
@@ -622,13 +625,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,

 class PromptExecutor:
     def __init__(self, server, cache_type=False, cache_args=None):
+        self.caches = None
         self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
+        if self.caches is not None:
+            for cache in self.caches.all:
+                comfy.model_management.unregister_ram_listener(cache)
+
         self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
+
+        for cache in self.caches.all:
+            comfy.model_management.register_ram_listener(cache)
         self.status_messages = []
         self.success = True

@@ -728,7 +739,7 @@ class PromptExecutor:
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
-                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+                    self.caches.outputs.free_ram()
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
diff --git a/nodes.py b/nodes.py
index 56b74ebe3..6794f6f05 100644
--- a/nodes.py
+++ b/nodes.py
@@ -60,7 +60,7 @@ class CLIPTextEncode(ComfyNodeABC):
         return {
             "required": {
                 "text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
-                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
+                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text.", "cache-barrier" : True})
             }
         }
     RETURN_TYPES = (IO.CONDITIONING,)
@@ -1518,7 +1518,7 @@ class KSampler:
     def INPUT_TYPES(s):
         return {
             "required": {
-                "model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
+                "model": ("MODEL", {"tooltip": "The model used for denoising the input latent.", "cache-barrier": True}),
                 "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
                 "steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
                 "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
@@ -1545,7 +1545,7 @@ class KSamplerAdvanced:
     @classmethod
     def INPUT_TYPES(s):
         return {"required":
-                    {"model": ("MODEL",),
+                    {"model": ("MODEL", {"cache-barrier": True}),
                     "add_noise": (["enable", "disable"], ),
                     "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
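
Note (not part of the patch): a minimal sketch of how the RAM-listener hooks added in comfy/model_management.py are intended to be used. Any object exposing a free_ram(extra_ram) method can be registered; PromptExecutor.reset() registers the execution caches this way. The LoggingRAMListener class below is hypothetical and only illustrates the protocol.

import comfy.model_management

class LoggingRAMListener:
    # Hypothetical listener: the only requirement is a free_ram(extra_ram) method.
    def free_ram(self, extra_ram):
        # extra_ram is the estimated number of bytes the caller is about to need,
        # e.g. the summed tensor sizes of a state dict passed to free_ram(state_dict=...).
        print("asked to make room for {:.2f} GB".format(extra_ram / (1024 ** 3)))

listener = LoggingRAMListener()
comfy.model_management.register_ram_listener(listener)
# Loading a checkpoint, VAE or CLIP now calls comfy.model_management.free_ram(state_dict=sd),
# which forwards the estimated byte count to every registered listener; RAMPressureCache
# responds by evicting cached node outputs until its headroom target is met.
comfy.model_management.unregister_ram_listener(listener)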
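
Note (not part of the patch): the "cache-barrier" input flag marked on the sampler/guider inputs above makes TopologicalSort defer the subgraph feeding that input until the consuming node is staged, then reschedules that subgraph at a higher priority (see unbarrier() and ux_friendly_pick_node()). A sketch of how a custom node could opt in, assuming the same legacy INPUT_TYPES dict format used by KSampler; the node itself is hypothetical.

class ExampleBarrierSampler:
    # Hypothetical node: marking "model" as a cache barrier keeps the checkpoint
    # loader feeding it unevaluated (and out of RAM) until this node is about to run.
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ("MODEL", {"cache-barrier": True}),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "run"
    CATEGORY = "_for_testing"

    def run(self, model, steps):
        return (model,)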