Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-18 18:43:05 +08:00)

Commit aaf06ace12: Merge branch 'master' into dr-support-pip-cm
@@ -201,6 +201,8 @@ Python 3.14 will work if you comment out the `kornia` dependency in the requirem
 
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
 
+### Instructions:
+
 Git clone this repo.
 
 Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
app/subgraph_manager.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+from typing import TypedDict
+import os
+import folder_paths
+import glob
+from aiohttp import web
+import hashlib
+
+
+class Source:
+    custom_node = "custom_node"
+
+class SubgraphEntry(TypedDict):
+    source: str
+    """
+    Source of subgraph - custom_nodes vs templates.
+    """
+    path: str
+    """
+    Relative path of the subgraph file.
+    For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
+    """
+    name: str
+    """
+    Name of subgraph file.
+    """
+    info: CustomNodeSubgraphEntryInfo
+    """
+    Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
+    """
+    data: str
+
+class CustomNodeSubgraphEntryInfo(TypedDict):
+    node_pack: str
+    """Node pack name."""
+
+class SubgraphManager:
+    def __init__(self):
+        self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
+
+    async def load_entry_data(self, entry: SubgraphEntry):
+        with open(entry['path'], 'r') as f:
+            entry['data'] = f.read()
+        return entry
+
+    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
+        if entry is None:
+            return None
+        entry = entry.copy()
+        entry.pop('path', None)
+        if remove_data:
+            entry.pop('data', None)
+        return entry
+
+    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
+        entries = entries.copy()
+        for key in list(entries.keys()):
+            entries[key] = await self.sanitize_entry(entries[key], remove_data)
+        return entries
+
+    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
+        # if not forced to reload and cached, return cache
+        if not force_reload and self.cached_custom_node_subgraphs is not None:
+            return self.cached_custom_node_subgraphs
+        # Load subgraphs from custom nodes
+        subfolder = "subgraphs"
+        subgraphs_dict: dict[SubgraphEntry] = {}
+
+        for folder in folder_paths.get_folder_paths("custom_nodes"):
+            pattern = os.path.join(folder, f"*/{subfolder}/*.json")
+            matched_files = glob.glob(pattern)
+            for file in matched_files:
+                # replace backslashes with forward slashes
+                file = file.replace('\\', '/')
+                info: CustomNodeSubgraphEntryInfo = {
+                    "node_pack": "custom_nodes." + file.split('/')[-3]
+                }
+                source = Source.custom_node
+                # hash source + path to make sure id will be as unique as possible, but
+                # reproducible across backend reloads
+                id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
+                entry: SubgraphEntry = {
+                    "source": Source.custom_node,
+                    "name": os.path.splitext(os.path.basename(file))[0],
+                    "path": file,
+                    "info": info,
+                }
+                subgraphs_dict[id] = entry
+        self.cached_custom_node_subgraphs = subgraphs_dict
+        return subgraphs_dict
+
+    async def get_custom_node_subgraph(self, id: str, loadedModules):
+        subgraphs = await self.get_custom_node_subgraphs(loadedModules)
+        entry: SubgraphEntry = subgraphs.get(id, None)
+        if entry is not None and entry.get('data', None) is None:
+            await self.load_entry_data(entry)
+        return entry
+
+    def add_routes(self, routes, loadedModules):
+        @routes.get("/global_subgraphs")
+        async def get_global_subgraphs(request):
+            subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
+            # NOTE: we may want to include other sources of global subgraphs such as templates in the future;
+            # that's the reasoning for the current implementation
+            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
+
+        @routes.get("/global_subgraphs/{id}")
+        async def get_global_subgraph(request):
+            id = request.match_info.get("id", None)
+            subgraph = await self.get_custom_node_subgraph(id, loadedModules)
+            return web.json_response(await self.sanitize_entry(subgraph))
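For orientation, the two GET routes registered above can be exercised with a plain aiohttp client. This is a minimal sketch, not part of the commit: it assumes a ComfyUI server already listening on 127.0.0.1:8188, and the printed fields are simply the keys defined in SubgraphEntry.

import asyncio
import aiohttp

BASE = "http://127.0.0.1:8188"  # assumed local ComfyUI address; adjust as needed

async def main():
    async with aiohttp.ClientSession() as session:
        # List all discovered custom-node subgraphs; 'data' is stripped server-side (remove_data=True).
        async with session.get(f"{BASE}/global_subgraphs") as resp:
            entries = await resp.json()
        for entry_id, entry in entries.items():
            print(entry_id[:12], entry["name"], entry["info"]["node_pack"])
            # Fetch a single entry by id; this response still carries the raw subgraph JSON in 'data'.
            async with session.get(f"{BASE}/global_subgraphs/{entry_id}") as resp:
                full = await resp.json()
            print("  ", len(full["data"]), "bytes of subgraph JSON")

asyncio.run(main())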
@@ -330,15 +330,21 @@ except:
 
 SUPPORT_FP8_OPS = args.supports_fp8_compute
 
+AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
+
 try:
     if is_amd():
+        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+        if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
             torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
             logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
 
         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
         except:
             rocm_version = (6, -1)
-        arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
         logging.info("AMD arch: {}".format(arch))
         logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:

@@ -1331,7 +1337,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
 
     if is_amd():
         arch = torch.cuda.get_device_properties(device).gcnArchName
-        if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
+        if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
            if manual_cast:
                return True
            return False
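As a side note, the gcnArchName test that this refactor centralizes is a plain substring match against the new module-level list. A standalone sketch, not from the commit (the arch strings below are illustrative examples):

AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]

def is_rdna2_or_older(arch: str) -> bool:
    # gcnArchName may carry feature suffixes, hence the substring test.
    return any(a in arch for a in AMD_RDNA2_AND_OLDER_ARCH)

print(is_rdna2_or_older("gfx1030:sramecc-:xnack-"))  # True: cuDNN stays enabled, bf16 is avoided
print(is_rdna2_or_older("gfx1100"))                  # False: newer architecture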
@@ -265,6 +265,26 @@ class HierarchicalCache(BasicCache):
         assert cache is not None
         return await cache._ensure_subcache(node_id, children_ids)
 
+class NullCache:
+
+    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+        pass
+
+    def all_node_ids(self):
+        return []
+
+    def clean_unused(self):
+        pass
+
+    def get(self, node_id):
+        return None
+
+    def set(self, node_id, value):
+        pass
+
+    async def ensure_subcache_for(self, node_id, children_ids):
+        return self
+
 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
         super().__init__(key_class)
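A quick sketch of what the new NullCache means in practice, outside the diff itself: nothing is ever retained, so every lookup misses and the executor treats all nodes as uncached. This assumes the repository is importable so the class added above is available.

import asyncio
from comfy_execution.caching import NullCache  # class added by this commit

async def demo():
    cache = NullCache()
    cache.set("5", ("some output",))
    assert cache.get("5") is None                                     # set() discarded the value
    assert cache.all_node_ids() == []                                 # nothing is tracked
    assert await cache.ensure_subcache_for("5", ["6", "7"]) is cache  # subcaches are the same no-op object

asyncio.run(demo())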
@@ -316,157 +336,3 @@ class LRUCache(BasicCache):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
-
-
-class DependencyAwareCache(BasicCache):
-    """
-    A cache implementation that tracks dependencies between nodes and manages
-    their execution and caching accordingly. It extends the BasicCache class.
-    Nodes are removed from this cache once all of their descendants have been
-    executed.
-    """
-
-    def __init__(self, key_class):
-        """
-        Initialize the DependencyAwareCache.
-
-        Args:
-            key_class: The class used for generating cache keys.
-        """
-        super().__init__(key_class)
-        self.descendants = {}  # Maps node_id -> set of descendant node_ids
-        self.ancestors = {}  # Maps node_id -> set of ancestor node_ids
-        self.executed_nodes = set()  # Tracks nodes that have been executed
-
-    async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
-        """
-        Clear the entire cache and rebuild the dependency graph.
-
-        Args:
-            dynprompt: The dynamic prompt object containing node information.
-            node_ids: List of node IDs to initialize the cache for.
-            is_changed_cache: Flag indicating if the cache has changed.
-        """
-        # Clear all existing cache data
-        self.cache.clear()
-        self.subcaches.clear()
-        self.descendants.clear()
-        self.ancestors.clear()
-        self.executed_nodes.clear()
-
-        # Call the parent method to initialize the cache with the new prompt
-        await super().set_prompt(dynprompt, node_ids, is_changed_cache)
-
-        # Rebuild the dependency graph
-        self._build_dependency_graph(dynprompt, node_ids)
-
-    def _build_dependency_graph(self, dynprompt, node_ids):
-        """
-        Build the dependency graph for all nodes.
-
-        Args:
-            dynprompt: The dynamic prompt object containing node information.
-            node_ids: List of node IDs to build the graph for.
-        """
-        self.descendants.clear()
-        self.ancestors.clear()
-        for node_id in node_ids:
-            self.descendants[node_id] = set()
-            self.ancestors[node_id] = set()
-
-        for node_id in node_ids:
-            inputs = dynprompt.get_node(node_id)["inputs"]
-            for input_data in inputs.values():
-                if is_link(input_data):  # Check if the input is a link to another node
-                    ancestor_id = input_data[0]
-                    self.descendants[ancestor_id].add(node_id)
-                    self.ancestors[node_id].add(ancestor_id)
-
-    def set(self, node_id, value):
-        """
-        Mark a node as executed and store its value in the cache.
-
-        Args:
-            node_id: The ID of the node to store.
-            value: The value to store for the node.
-        """
-        self._set_immediate(node_id, value)
-        self.executed_nodes.add(node_id)
-        self._cleanup_ancestors(node_id)
-
-    def get(self, node_id):
-        """
-        Retrieve the cached value for a node.
-
-        Args:
-            node_id: The ID of the node to retrieve.
-
-        Returns:
-            The cached value for the node.
-        """
-        return self._get_immediate(node_id)
-
-    async def ensure_subcache_for(self, node_id, children_ids):
-        """
-        Ensure a subcache exists for a node and update dependencies.
-
-        Args:
-            node_id: The ID of the parent node.
-            children_ids: List of child node IDs to associate with the parent node.
-
-        Returns:
-            The subcache object for the node.
-        """
-        subcache = await super()._ensure_subcache(node_id, children_ids)
-        for child_id in children_ids:
-            self.descendants[node_id].add(child_id)
-            self.ancestors[child_id].add(node_id)
-        return subcache
-
-    def _cleanup_ancestors(self, node_id):
-        """
-        Check if ancestors of a node can be removed from the cache.
-
-        Args:
-            node_id: The ID of the node whose ancestors are to be checked.
-        """
-        for ancestor_id in self.ancestors.get(node_id, []):
-            if ancestor_id in self.executed_nodes:
-                # Remove ancestor if all its descendants have been executed
-                if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
-                    self._remove_node(ancestor_id)
-
-    def _remove_node(self, node_id):
-        """
-        Remove a node from the cache.
-
-        Args:
-            node_id: The ID of the node to remove.
-        """
-        cache_key = self.cache_key_set.get_data_key(node_id)
-        if cache_key in self.cache:
-            del self.cache[cache_key]
-        subcache_key = self.cache_key_set.get_subcache_key(node_id)
-        if subcache_key in self.subcaches:
-            del self.subcaches[subcache_key]
-
-    def clean_unused(self):
-        """
-        Clean up unused nodes. This is a no-op for this cache implementation.
-        """
-        pass
-
-    def recursive_debug_dump(self):
-        """
-        Dump the cache and dependency graph for debugging.
-
-        Returns:
-            A list containing the cache state and dependency graph.
-        """
-        result = super().recursive_debug_dump()
-        result.append({
-            "descendants": self.descendants,
-            "ancestors": self.ancestors,
-            "executed_nodes": list(self.executed_nodes),
-        })
-        return result
@@ -153,7 +153,8 @@ class TopologicalSort:
                         continue
                     _, _, input_info = self.get_input_info(unique_id, input_name)
                     is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
-                    if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
-                        node_ids.append(from_node_id)
+                    if (include_lazy or not is_lazy):
+                        if not self.is_cached(from_node_id):
+                            node_ids.append(from_node_id)
                         links.append((from_node_id, from_socket, unique_id))
 

@@ -194,10 +195,35 @@ class ExecutionList(TopologicalSort):
         super().__init__(dynprompt)
         self.output_cache = output_cache
         self.staged_node_id = None
+        self.execution_cache = {}
+        self.execution_cache_listeners = {}
 
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None
 
+    def cache_link(self, from_node_id, to_node_id):
+        if not to_node_id in self.execution_cache:
+            self.execution_cache[to_node_id] = {}
+        self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
+        if not from_node_id in self.execution_cache_listeners:
+            self.execution_cache_listeners[from_node_id] = set()
+        self.execution_cache_listeners[from_node_id].add(to_node_id)
+
+    def get_output_cache(self, from_node_id, to_node_id):
+        if not to_node_id in self.execution_cache:
+            return None
+        return self.execution_cache[to_node_id].get(from_node_id)
+
+    def cache_update(self, node_id, value):
+        if node_id in self.execution_cache_listeners:
+            for to_node_id in self.execution_cache_listeners[node_id]:
+                if to_node_id in self.execution_cache:
+                    self.execution_cache[to_node_id][node_id] = value
+
+    def add_strong_link(self, from_node_id, from_socket, to_node_id):
+        super().add_strong_link(from_node_id, from_socket, to_node_id)
+        self.cache_link(from_node_id, to_node_id)
+
     async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():

@@ -277,6 +303,8 @@ class ExecutionList(TopologicalSort):
     def complete_node_execution(self):
         node_id = self.staged_node_id
         self.pop_node(node_id)
+        self.execution_cache.pop(node_id, None)
+        self.execution_cache_listeners.pop(node_id, None)
         self.staged_node_id = None
 
     def get_nodes_in_cycle(self):
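The execution cache added to ExecutionList above is a per-consumer snapshot of producer outputs: cache_link records the edge (and whatever the output cache currently holds), cache_update refreshes every registered consumer once the producer actually runs, and get_output_cache is what the input-gathering code reads. A stripped-down sketch with plain dicts standing in for the class, names and values chosen here purely for illustration:

output_cache = {}               # stand-in for self.output_cache (may be a NullCache with --cache-none)
execution_cache = {}            # to_node_id -> {from_node_id: cached output}
execution_cache_listeners = {}  # from_node_id -> set of consumer node ids

def cache_link(from_node_id, to_node_id):
    execution_cache.setdefault(to_node_id, {})[from_node_id] = output_cache.get(from_node_id)
    execution_cache_listeners.setdefault(from_node_id, set()).add(to_node_id)

def cache_update(node_id, value):
    # Called after node_id executes: push its fresh output to every linked consumer.
    for to_node_id in execution_cache_listeners.get(node_id, set()):
        if to_node_id in execution_cache:
            execution_cache[to_node_id][node_id] = value

def get_output_cache(from_node_id, to_node_id):
    return execution_cache.get(to_node_id, {}).get(from_node_id)

cache_link("3", "7")                        # edge 3 -> 7 registered before node 3 has run
assert get_output_cache("3", "7") is None
cache_update("3", ("image tensor",))        # node 3 finishes
assert get_output_cache("3", "7") == ("image tensor",)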
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.3.65"
+__version__ = "0.3.66"
execution.py (34 changed lines)
@@ -18,7 +18,7 @@ from comfy_execution.caching import (
     BasicCache,
     CacheKeySetID,
     CacheKeySetInputSignature,
-    DependencyAwareCache,
+    NullCache,
     HierarchicalCache,
     LRUCache,
 )

@@ -91,13 +91,13 @@ class IsChangedCache:
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
-    DEPENDENCY_AWARE = 2
+    NONE = 2
 
 
 class CacheSet:
     def __init__(self, cache_type=None, cache_size=None):
-        if cache_type == CacheType.DEPENDENCY_AWARE:
-            self.init_dependency_aware_cache()
+        if cache_type == CacheType.NONE:
+            self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
         elif cache_type == CacheType.LRU:
             if cache_size is None:

@@ -120,11 +120,12 @@ class CacheSet:
         self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)
 
-    # only hold cached items while the decendents have not executed
-    def init_dependency_aware_cache(self):
-        self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
-        self.ui = DependencyAwareCache(CacheKeySetInputSignature)
-        self.objects = DependencyAwareCache(CacheKeySetID)
+    def init_null_cache(self):
+        self.outputs = NullCache()
+        #The UI cache is expected to be iterable at the end of each workflow
+        #so it must cache at least a full workflow. Use Heirachical
+        self.ui = HierarchicalCache(CacheKeySetInputSignature)
+        self.objects = NullCache()
 
     def recursive_debug_dump(self):
         result = {

@@ -135,7 +136,7 @@ class CacheSet:
 
 SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
 
-def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
+def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
     is_v3 = issubclass(class_def, _ComfyNodeInternal)
     if is_v3:
         valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)

@@ -153,10 +154,10 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
         if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
             input_unique_id = input_data[0]
             output_index = input_data[1]
-            if outputs is None:
+            if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = outputs.get(input_unique_id)
+            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
             if cached_output is None:
                 mark_missing()
                 continue

@@ -405,6 +406,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         cached_output = caches.ui.get(unique_id) or {}
         server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
         get_progress_state().finish_progress(unique_id)
+        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
         return (ExecutionResult.SUCCESS, None, None)
 
     input_data_all = None

@@ -434,7 +436,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             for r in result:
                 if is_link(r):
                     source_node, source_output = r[0], r[1]
-                    node_output = caches.outputs.get(source_node)[source_output]
+                    node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
                     for o in node_output:
                         resolved_output.append(o)
 

@@ -446,7 +448,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             has_subgraph = False
         else:
             get_progress_state().start_progress(unique_id)
-            input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
+            input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
             if server.client_id is not None:
                 server.last_node_id = display_node_id
                 server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)

@@ -549,11 +551,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 subcache.clean_unused()
                 for node_id in new_output_ids:
                     execution_list.add_node(node_id)
+                    execution_list.cache_link(node_id, unique_id)
                 for link in new_output_links:
                     execution_list.add_strong_link(link[0], link[1], unique_id)
                 pending_subgraph_results[unique_id] = cached_outputs
                 return (ExecutionResult.PENDING, None, None)
 
         caches.outputs.set(unique_id, output_data)
+        execution_list.cache_update(unique_id, output_data)
 
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")
main.py (2 changed lines)
@@ -199,7 +199,7 @@ def prompt_worker(q, server_instance):
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
     elif args.cache_none:
-        cache_type = execution.CacheType.DEPENDENCY_AWARE
+        cache_type = execution.CacheType.NONE
 
     e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
     last_gc_collect = 0
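Read together with the execution.py change above, the flag-to-cache mapping after this commit is: --cache-lru N selects CacheType.LRU, and --cache-none now selects CacheType.NONE (NullCache-backed outputs in place of the removed DependencyAwareCache). A small sketch of that selection, mirroring the diff; it assumes ComfyUI's execution module is importable and that the default without either flag remains CacheType.CLASSIC, and the function name is ours:

import execution

def pick_cache_type(args):
    # Mirrors prompt_worker() in main.py after this change (default assumed to be CLASSIC).
    cache_type = execution.CacheType.CLASSIC
    if args.cache_lru > 0:
        cache_type = execution.CacheType.LRU
    elif args.cache_none:
        cache_type = execution.CacheType.NONE
    return cache_type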
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.3.65"
+version = "0.3.66"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@@ -35,6 +35,7 @@ from comfy_api.internal import _ComfyNodeInternal
 from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
+from app.subgraph_manager import SubgraphManager
 from typing import Optional, Union
 from api_server.routes.internal.internal_routes import InternalRoutes
 from protocol import BinaryEventTypes

@@ -176,6 +177,7 @@ class PromptServer():
         self.user_manager = UserManager()
         self.model_file_manager = ModelFileManager()
         self.custom_node_manager = CustomNodeManager()
+        self.subgraph_manager = SubgraphManager()
         self.internal_routes = InternalRoutes(self)
         self.supports = ["custom_nodes_from_web"]
         self.prompt_queue = execution.PromptQueue(self)

@@ -825,6 +827,7 @@ class PromptServer():
         self.user_manager.add_routes(self.routes)
         self.model_file_manager.add_routes(self.routes)
         self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
+        self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
         self.app.add_subapp('/internal', self.internal_routes.get_app())
 
         # Prefix every route with /api for easier matching for delegation.
@@ -152,12 +152,12 @@ class TestExecution:
     # Initialize server and client
     #
     @fixture(scope="class", autouse=True, params=[
-        # (use_lru, lru_size)
-        (False, 0),
-        (True, 0),
-        (True, 100),
+        { "extra_args" : [], "should_cache_results" : True },
+        { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
+        { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
+        { "extra_args" : ["--cache-none"], "should_cache_results" : False },
     ])
-    def _server(self, args_pytest, request):
+    def server(self, args_pytest, request):
         # Start server
         pargs = [
             'python','main.py',

@@ -167,12 +167,10 @@ class TestExecution:
             '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
             '--cpu',
         ]
-        use_lru, lru_size = request.param
-        if use_lru:
-            pargs += ['--cache-lru', str(lru_size)]
+        pargs += [ str(param) for param in request.param["extra_args"] ]
         print("Running server with args:", pargs) # noqa: T201
         p = subprocess.Popen(pargs)
-        yield
+        yield request.param
         p.kill()
         torch.cuda.empty_cache()
 

@@ -193,7 +191,7 @@ class TestExecution:
         return comfy_client
 
     @fixture(scope="class", autouse=True)
-    def shared_client(self, args_pytest, _server):
+    def shared_client(self, args_pytest, server):
         client = self.start_client(args_pytest["listen"], args_pytest["port"])
         yield client
         del client

@@ -225,7 +223,7 @@ class TestExecution:
         assert result.did_run(mask)
         assert result.did_run(lazy_mix)
 
-    def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
+    def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)

@@ -237,9 +235,12 @@ class TestExecution:
         client.run(g)
         result2 = client.run(g)
         for node_id, node in g.nodes.items():
-            assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
+            if server["should_cache_results"]:
+                assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
+            else:
+                assert result2.did_run(node), f"Node {node_id} was cached, but should have been run"
 
-    def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
+    def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)

@@ -251,8 +252,12 @@ class TestExecution:
         client.run(g)
         mask.inputs['value'] = 0.4
         result2 = client.run(g)
-        assert not result2.did_run(input1), "Input1 should have been cached"
-        assert not result2.did_run(input2), "Input2 should have been cached"
+        if server["should_cache_results"]:
+            assert not result2.did_run(input1), "Input1 should have been cached"
+            assert not result2.did_run(input2), "Input2 should have been cached"
+        else:
+            assert result2.did_run(input1), "Input1 should have been rerun"
+            assert result2.did_run(input2), "Input2 should have been rerun"
 
     def test_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder

@@ -411,7 +416,7 @@ class TestExecution:
         input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
         client.run(g)
 
-    def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
+    def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         # Creating the nodes in this specific order previously caused a bug
         save = g.node("SaveImage")

@@ -427,7 +432,10 @@ class TestExecution:
         result3 = client.run(g)
         result4 = client.run(g)
         assert result1.did_run(is_changed), "is_changed should have been run"
-        assert not result2.did_run(is_changed), "is_changed should have been cached"
+        if server["should_cache_results"]:
+            assert not result2.did_run(is_changed), "is_changed should have been cached"
+        else:
+            assert result2.did_run(is_changed), "is_changed should have been re-run"
         assert result3.did_run(is_changed), "is_changed should have been re-run"
         assert result4.did_run(is_changed), "is_changed should not have been cached"
 

@@ -514,7 +522,7 @@ class TestExecution:
         assert len(images2) == 1, "Should have 1 image"
 
     # This tests that only constant outputs are used in the call to `IS_CHANGED`
-    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
+    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
         test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)

@@ -530,7 +538,11 @@ class TestExecution:
         images = result.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
-        assert not result.did_run(test_node), "The execution should have been cached"
+        if server["should_cache_results"]:
+            assert not result.did_run(test_node), "The execution should have been cached"
+        else:
+            assert result.did_run(test_node), "The execution should have been re-run"
+
 
     def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
         # Warmup execution to ensure server is fully initialized