mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-07 03:52:32 +08:00)
Compare commits
5 Commits
be25470d88
...
f7c36abe7f
| SHA1 |
|---|
| f7c36abe7f |
| c05a08ae66 |
| de9ada6a41 |
| 37f711d4a1 |
| c463359308 |
```diff
@@ -19,7 +19,8 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+import threading
 import torch
 import sys
 import platform
```
```diff
@@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
         soft_empty_cache()
     return unloaded_models

-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state

```
```diff
@@ -746,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         current_loaded_models.insert(0, loaded_model)
     return

+def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
+    with torch.inference_mode():
+        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+    soft_empty_cache()
+
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    # Deliberately load models outside of the Aimdo mempool so they can be retained across
+    # nodes. Use a dummy thread to do it, as PyTorch documents that mempool contexts are
+    # thread local, so we exploit that to escape the context.
+    if enables_dynamic_vram():
+        t = threading.Thread(
+            target=load_models_gpu_thread,
+            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+        )
+        t.start()
+        t.join()
+    else:
+        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
+                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
+
 def load_model_gpu(model):
     return load_models_gpu([model])

```
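The comment in the new `load_models_gpu` wrapper leans on PyTorch's documented behavior that a mempool context only applies to the thread that entered it. Below is a minimal, standalone sketch of that behavior (illustrative only, not part of the patch), assuming a CUDA build of PyTorch recent enough (roughly 2.5+) to expose `torch.cuda.MemPool` and `torch.cuda.use_mem_pool`:

```python
# Illustrative sketch (not ComfyUI code): allocations made in a worker thread are not
# routed through the mempool context entered on the main thread, because that context
# is thread local. This mirrors the "dummy thread" trick in load_models_gpu above.
import threading
import torch

def allocate_in_worker(out):
    # Runs in a fresh thread: the main thread's use_mem_pool() context does not apply
    # here, so this tensor comes from the default caching allocator instead of the pool.
    out["weights"] = torch.empty(1024, 1024, device="cuda")

result = {}
pool = torch.cuda.MemPool()  # assumes PyTorch >= 2.5 with CUDA
with torch.cuda.use_mem_pool(pool):
    scratch = torch.empty(1024, 1024, device="cuda")  # pooled allocation
    worker = threading.Thread(target=allocate_in_worker, args=(result,))
    worker.start()
    worker.join()
# result["weights"] outlives the pool context, which is the point of loading model
# weights through a separate thread.
```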
```diff
@@ -1112,11 +1133,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         # I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        torch.cuda.synchronize()
+        synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         # FIXME: This doesn't work in Aimdo because the mempool can't clear the cache
-        torch.cuda.empty_cache()
+        soft_empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
     STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
```
```diff
@@ -1132,9 +1153,7 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    if comfy.memory_management.aimdo_allocator is None:
-        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
-        torch.cuda.empty_cache()
+    soft_empty_cache()

 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
```
```diff
@@ -1284,7 +1303,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        torch.cuda.synchronize()
+        synchronize()
     except torch.AcceleratorError:
         # Dump it! We already know about it from the synchronous return
         pass
```
```diff
@@ -1688,6 +1707,12 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype

+def synchronize():
+    if is_intel_xpu():
+        torch.xpu.synchronize()
+    elif torch.cuda.is_available():
+        torch.cuda.synchronize()
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
```
```diff
@@ -1713,9 +1738,6 @@ def debug_memory_summary():
         return torch.cuda.memory.memory_summary()
     return ""

-#TODO: might be cleaner to put this somewhere else
-import threading
-
 class InterruptProcessingException(Exception):
     pass

```
```diff
@@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):

         if unpatch_weights:
             self.partially_unload_ram(1e32)
-            self.partially_unload(None)
+            self.partially_unload(None, 1e32)

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights  # See above
```
169 execution.py
```diff
@@ -48,6 +48,167 @@ class ExecutionResult(Enum):
 class DuplicateNodeError(Exception):
     pass

+
+# ======================================================================================
+# ADDED: Node grouping helpers for "input-type locality" execution ordering
+# --------------------------------------------------------------------------------------
+# We cluster ready-to-run nodes by a signature derived from:
+#   - Declared INPUT_TYPES (required/optional socket types)
+#   - Upstream linked RETURN_TYPES (when available from prompt links)
+#
+# This is a SCHEDULING optimization only:
+#   - It must not change correctness or dependency satisfaction.
+#   - It only reorders nodes that ExecutionList already deems ready/executable.
+#   - It is stable, to avoid churn and to preserve deterministic behavior.
+#
+# IMPORTANT: ExecutionList is imported from comfy_execution.graph; we avoid invasive
+# changes by using a small subclass + defensive introspection of its internal queues.
+# ======================================================================================
+
+def _safe_stringify_type(t):
+    try:
+        return str(t)
+    except Exception:
+        return repr(t)
+
```
```diff
+def _node_input_signature_from_prompt(prompt: dict, node_id: str):
+    """
+    Build a stable, hashable signature representing a node's *input requirements*.
+
+    Includes:
+      - Declared input socket types via INPUT_TYPES() (required + optional)
+      - Linked upstream output RETURN_TYPES, when an input is a link
+
+    This signature is used ONLY for grouping/sorting ready nodes.
+    """
+    node = prompt.get(node_id)
+    if node is None:
+        return ("<missing-node>", node_id)
+
+    class_type = node.get("class_type")
+    class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type)
+    if class_def is None:
+        return ("<missing-class>", class_type, node_id)
+
+    sig = []
+
+    # Declared socket types (required/optional)
+    try:
+        input_types = class_def.INPUT_TYPES()
+    except Exception:
+        input_types = {}
+
+    for cat in ("required", "optional"):
+        cat_dict = input_types.get(cat, {})
+        if isinstance(cat_dict, dict):
+            # Sort keys for stability
+            for k in sorted(cat_dict.keys()):
+                v = cat_dict[k]
+                sig.append(("decl", cat, k, _safe_stringify_type(v)))
+
+    # Linked upstream return types (helps cluster by latent/model flows)
+    inputs = node.get("inputs", {}) or {}
+    if isinstance(inputs, dict):
+        for k in sorted(inputs.keys()):
+            v = inputs[k]
+            if is_link(v) and isinstance(v, (list, tuple)) and len(v) == 2:
+                src_id, out_idx = v[0], v[1]
+                src_node = prompt.get(src_id)
+                if src_node is None:
+                    sig.append(("link", k, "<missing-src-node>"))
+                    continue
+                src_class_type = src_node.get("class_type")
+                src_class_def = nodes.NODE_CLASS_MAPPINGS.get(src_class_type)
+                if src_class_def is None:
+                    sig.append(("link", k, "<missing-src-class>", src_class_type))
+                    continue
+                ret_types = getattr(src_class_def, "RETURN_TYPES", ())
+                try:
+                    if isinstance(out_idx, int) and out_idx < len(ret_types):
+                        sig.append(("link", k, _safe_stringify_type(ret_types[out_idx])))
+                    else:
+                        sig.append(("link", k, "<bad-out-idx>", _safe_stringify_type(out_idx)))
+                except Exception:
+                    sig.append(("link", k, "<ret-type-error>"))
+
+    return tuple(sig)
+
```
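For intuition, here is a self-contained toy of the signature idea; `ToyKSampler`, `TOY_REGISTRY`, and `toy_signature` are hypothetical stand-ins (the real helper above uses `nodes.NODE_CLASS_MAPPINGS` and `is_link`, and also folds in linked RETURN_TYPES, which this toy skips for brevity):

```python
# Hypothetical stand-ins for illustration only; real signatures come from
# _node_input_signature_from_prompt() above using ComfyUI's node registry.
class ToyKSampler:
    RETURN_TYPES = ("LATENT",)

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"model": ("MODEL",), "latent_image": ("LATENT",)}}

TOY_REGISTRY = {"ToyKSampler": ToyKSampler}

def toy_signature(prompt, node_id):
    # Declared-socket part of the signature only (links are ignored in this toy).
    cls = TOY_REGISTRY[prompt[node_id]["class_type"]]
    required = cls.INPUT_TYPES()["required"]
    return tuple(("decl", "required", k, str(required[k])) for k in sorted(required))

prompt = {
    "7": {"class_type": "ToyKSampler", "inputs": {"model": ["1", 0], "latent_image": ["5", 0]}},
    "8": {"class_type": "ToyKSampler", "inputs": {"model": ["1", 0], "latent_image": ["6", 0]}},
}

# Both sampler nodes produce the same tuple, so a signature-keyed sort places them
# next to each other in the ready list.
assert toy_signature(prompt, "7") == toy_signature(prompt, "8")
```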
```diff
+def _try_group_sort_execution_list_ready_nodes(execution_list: ExecutionList, prompt: dict):
+    """
+    Attempt to reorder the ExecutionList's *ready* nodes in-place, grouping by input signature.
+
+    This is intentionally defensive because ExecutionList is external; we only touch
+    well-known/observed internal attributes when they match expected shapes.
+
+    Supported patterns (best-effort):
+      - execution_list.nodes_to_execute  : list[node_id, ...]
+      - execution_list._nodes_to_execute : list[node_id, ...] (fallback)
+
+    We DO NOT rewrite heaps/tuples with priority keys, because that risks breaking invariants.
+    If the internal structure is not a simple list of node_ids, we do nothing.
+    """
+    # Candidate attribute names that (in some ComfyUI revisions) hold ready-to-run node IDs
+    candidates = ("nodes_to_execute", "_nodes_to_execute")
+    for attr in candidates:
+        if not hasattr(execution_list, attr):
+            continue
+        value = getattr(execution_list, attr)
+
+        # Only operate on a plain list of node ids (strings/ints)
+        if isinstance(value, list) and all(isinstance(x, (str, int)) for x in value):
+            # Stable grouping sort:
+            #   primary:   signature (to cluster similar input requirements)
+            #   secondary: original order (stability)
+            indexed = list(enumerate(value))
+            indexed.sort(
+                key=lambda it: (
+                    # signature key
+                    _node_input_signature_from_prompt(prompt, str(it[1])),
+                    # keep stable within the same signature
+                    it[0],
+                )
+            )
+            new_list = [node_id for _, node_id in indexed]
+            setattr(execution_list, attr, new_list)
+            return True
+
+    return False
+
+
```
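The reorder itself is an ordinary stable grouping sort keyed by (signature, original position): equal signatures cluster together while the pre-existing order inside each cluster is preserved. A stripped-down sketch with made-up string signatures in place of the real tuples:

```python
# Minimal sketch of the (signature, original_index) key used above; the node ids and
# "signatures" here are made-up examples, not real ComfyUI data.
ready = ["10", "4", "7", "12"]
signatures = {"10": "vae", "4": "sampler", "7": "vae", "12": "sampler"}

indexed = list(enumerate(ready))
indexed.sort(key=lambda it: (signatures[it[1]], it[0]))
grouped = [node_id for _, node_id in indexed]

# Nodes with the same signature are now adjacent, and within each group the original
# relative order is kept ("4" before "12", "10" before "7").
assert grouped == ["4", "12", "10", "7"]
```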
```diff
+class GroupedExecutionList(ExecutionList):
+    """
+    ADDED: Thin wrapper around ExecutionList that reorders *ready* nodes before staging
+    to improve model/tensor locality (reduce VRAM/RAM chatter).
+
+    This does not change dependency logic; it only reorders nodes that are already ready.
+    """
+
+    def _apply_group_sort_if_possible(self):
+        try:
+            # dynprompt.original_prompt is the canonical prompt graph dict
+            prompt = getattr(self, "dynprompt", None)
+            prompt_dict = None
+            if prompt is not None:
+                prompt_dict = getattr(prompt, "original_prompt", None)
+            if isinstance(prompt_dict, dict):
+                _try_group_sort_execution_list_ready_nodes(self, prompt_dict)
+        except Exception:
+            # Must never break execution
+            pass
+
+    # NOTE: stage_node_execution is awaited in the caller in this file, so we keep it async-compatible.
+    async def stage_node_execution(self):
+        # Group-sort the ready list *before* choosing the next node
+        self._apply_group_sort_if_possible()
+        return await super().stage_node_execution()
+
+    def add_node(self, node_id):
+        # Keep original behavior, then regroup for future staging
+        super().add_node(node_id)
+        self._apply_group_sort_if_possible()
+
+
 class IsChangedCache:
     def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache):
         self.prompt_id = prompt_id
```
```diff
@@ -721,7 +882,13 @@ class PromptExecutor:
         pending_async_nodes = {}  # TODO - Unify this with pending_subgraph_results
         ui_node_outputs = {}
         executed = set()
-        execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+
+        # ==================================================================================
+        # CHANGED: Use GroupedExecutionList to group ready-to-run nodes by input signature.
+        # This reduces VRAM/RAM chatter when workflows reuse the same models/tensor types.
+        # ==================================================================================
+        execution_list = GroupedExecutionList(dynamic_prompt, self.caches.outputs)
+
         current_outputs = self.caches.outputs.all_node_ids()
         for node_id in list(execute_outputs):
             execution_list.add_node(node_id)
```