mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Improve vanilla node importing
This commit is contained in:
parent
54b94e4c87
commit
175a50d7ba
@ -147,10 +147,6 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
|
||||
```shell
|
||||
comfyui
|
||||
```
|
||||
To generate python OpenAPI models:
|
||||
```shell
|
||||
comfyui-openapi-gen
|
||||
```
|
||||
To run tests:
|
||||
```shell
|
||||
pytest tests/inference
|
||||
|
||||
@ -10,21 +10,18 @@ import traceback
|
||||
import typing
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
# Suppress warnings during import
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
|
||||
|
||||
import torch
|
||||
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from .. import model_management
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
|
||||
nodes = import_all_nodes_in_workspace()
|
||||
from .. import model_management # type: ignore
|
||||
|
||||
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
|
||||
# various functions that are declared here. It should have been a context in the first place.
|
||||
nodes: ExportedNodes = import_all_nodes_in_workspace()
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
|
||||
if extra_data is None:
|
||||
@ -88,7 +85,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
|
||||
results.append(getattr(obj, func)(**input_data_all))
|
||||
elif max_len_input == 0:
|
||||
if allow_interrupt:
|
||||
nodes.before_node_execution()
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
results.append(getattr(obj, func)())
|
||||
else:
|
||||
for i in range(max_len_input):
|
||||
@ -141,7 +138,14 @@ def format_value(x):
|
||||
return str(x)
|
||||
|
||||
|
||||
def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
|
||||
def recursive_execute(server: ExecutorToClientProgress,
|
||||
prompt,
|
||||
outputs,
|
||||
current_item,
|
||||
extra_data,
|
||||
executed,
|
||||
prompt_id,
|
||||
outputs_ui,
|
||||
object_storage):
|
||||
unique_id = current_item
|
||||
inputs = prompt[unique_id]['inputs']
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from .. import options
|
||||
# Suppress warnings during import
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
|
||||
options.enable_args_parsing()
|
||||
|
||||
import os
|
||||
@ -9,6 +12,8 @@ from ..cmd import cuda_malloc
|
||||
from ..cmd import folder_paths
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
from ..analytics.analytics import initialize_event_tracking
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
|
||||
import time
|
||||
|
||||
|
||||
@ -77,7 +82,6 @@ if args.deterministic:
|
||||
|
||||
from .. import utils
|
||||
|
||||
from ..cmd import execution
|
||||
from ..cmd import server as server_module
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus
|
||||
@ -88,7 +92,9 @@ from ..distributed.server_stub import ServerStub
|
||||
|
||||
|
||||
def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
|
||||
e = execution.PromptExecutor(_server)
|
||||
from ..cmd.execution import PromptExecutor
|
||||
|
||||
e = PromptExecutor(_server)
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
@ -202,10 +208,27 @@ async def main():
|
||||
except:
|
||||
pass
|
||||
|
||||
# configure extra model paths earlier
|
||||
try:
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
load_extra_path_config(extra_model_paths_config_path)
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
if args.extra_model_paths_config:
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
load_extra_path_config(config_path)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
server = server_module.PromptServer(loop)
|
||||
if args.external_address is not None:
|
||||
server.external_address = args.external_address
|
||||
|
||||
# at this stage, it's safe to import nodes
|
||||
server.nodes = import_all_nodes_in_workspace()
|
||||
# as a side effect, this also populates the nodes for execution
|
||||
|
||||
if args.distributed_queue_connection_uri is not None:
|
||||
distributed = True
|
||||
q = DistributedPromptQueue(
|
||||
@ -219,19 +242,10 @@ async def main():
|
||||
await q.init()
|
||||
else:
|
||||
distributed = False
|
||||
q = execution.PromptQueue(server)
|
||||
from execution import PromptQueue
|
||||
q = PromptQueue(server)
|
||||
server.prompt_queue = q
|
||||
|
||||
try:
|
||||
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
|
||||
if os.path.isfile(extra_model_paths_config_path):
|
||||
load_extra_path_config(extra_model_paths_config_path)
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
if args.extra_model_paths_config:
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
load_extra_path_config(config_path)
|
||||
|
||||
server.add_routes()
|
||||
hijack_progress(server)
|
||||
|
||||
@ -34,10 +34,8 @@ from .. import model_management
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.file_output_path import file_output_path
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
from ..nodes.package_typing import ExportedNodes
|
||||
from ..vendor.appdirs import user_data_dir
|
||||
|
||||
nodes = import_all_nodes_in_workspace()
|
||||
|
||||
from ..app.user_manager import UserManager
|
||||
|
||||
|
||||
@ -114,6 +112,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.client_id = None
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
self.nodes: ExportedNodes = ExportedNodes()
|
||||
|
||||
@routes.get('/ws')
|
||||
async def websocket_handler(request):
|
||||
@ -157,7 +156,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
||||
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
||||
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
|
||||
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
|
||||
extensions.extend(list(map(lambda f: "/extensions/" + quote(
|
||||
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||
@ -398,7 +397,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return web.json_response(self.get_queue_info())
|
||||
|
||||
def node_info(node_class):
|
||||
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class]
|
||||
info = {}
|
||||
info['input'] = obj_class.INPUT_TYPES()
|
||||
info['output'] = obj_class.RETURN_TYPES
|
||||
@ -407,8 +406,8 @@ class PromptServer(ExecutorToClientProgress):
|
||||
obj_class.RETURN_TYPES)
|
||||
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
|
||||
info['name'] = node_class
|
||||
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[
|
||||
node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[
|
||||
node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
|
||||
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
|
||||
info['category'] = 'sd'
|
||||
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
|
||||
@ -423,7 +422,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
for x in self.nodes.NODE_CLASS_MAPPINGS:
|
||||
try:
|
||||
out[x] = node_info(x)
|
||||
except Exception as e:
|
||||
@ -436,7 +435,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
async def get_object_info_node(request):
|
||||
node_class = request.match_info.get("node_class", None)
|
||||
out = {}
|
||||
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||
if (node_class is not None) and (node_class in self.nodes.NODE_CLASS_MAPPINGS):
|
||||
out[node_class] = node_info(node_class)
|
||||
return web.json_response(out)
|
||||
|
||||
@ -710,7 +709,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.app.add_routes(self.routes)
|
||||
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
|
||||
self.app.add_routes([
|
||||
web.static('/extensions/' + quote(name), dir),
|
||||
])
|
||||
|
||||
@ -16,7 +16,7 @@ from . import base_nodes
|
||||
from .package_typing import ExportedNodes
|
||||
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
|
||||
|
||||
_comfy_nodes = ExportedNodes()
|
||||
_comfy_nodes: ExportedNodes = ExportedNodes()
|
||||
|
||||
|
||||
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
|
||||
@ -77,7 +77,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import
|
||||
|
||||
|
||||
def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
||||
global _comfy_nodes
|
||||
# only load these nodes once
|
||||
if len(_comfy_nodes) == 0:
|
||||
base_and_extra = reduce(lambda x, y: x.update(y),
|
||||
map(_import_and_enumerate_nodes_in_module, [
|
||||
@ -88,9 +88,6 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
||||
ExportedNodes())
|
||||
custom_nodes_mappings = ExportedNodes()
|
||||
|
||||
if vanilla_custom_nodes:
|
||||
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
|
||||
|
||||
# load from entrypoints
|
||||
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
||||
# Load the module associated with the current entry point
|
||||
@ -100,6 +97,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
|
||||
if isinstance(module, types.ModuleType):
|
||||
custom_nodes_mappings.update(
|
||||
_import_and_enumerate_nodes_in_module(module, print_import_times=True))
|
||||
|
||||
# load the vanilla custom nodes last
|
||||
if vanilla_custom_nodes:
|
||||
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
|
||||
|
||||
# don't allow custom nodes to overwrite base nodes
|
||||
custom_nodes_mappings -= base_and_extra
|
||||
|
||||
|
||||
@ -124,12 +124,12 @@ class ExportedNodes:
|
||||
|
||||
def __sub__(self, other: ExportedNodes):
|
||||
exported_nodes = ExportedNodes().update(self)
|
||||
for self_key in exported_nodes.NODE_CLASS_MAPPINGS:
|
||||
for self_key in frozenset(exported_nodes.NODE_CLASS_MAPPINGS):
|
||||
if self_key in other.NODE_CLASS_MAPPINGS:
|
||||
exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
|
||||
if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
|
||||
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
|
||||
for self_key in exported_nodes.EXTENSION_WEB_DIRS:
|
||||
for self_key in frozenset(exported_nodes.EXTENSION_WEB_DIRS):
|
||||
if self_key in other.EXTENSION_WEB_DIRS:
|
||||
exported_nodes.EXTENSION_WEB_DIRS.pop(self_key)
|
||||
return exported_nodes
|
||||
|
||||
@ -89,17 +89,31 @@ def _vanilla_load_custom_nodes_2() -> ExportedNodes:
|
||||
|
||||
def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
||||
# only vanilla custom nodes will ever go into the custom_nodes directory
|
||||
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be found
|
||||
from ..cmd import cuda_malloc, folder_paths, execution, server, latent_preview
|
||||
for module in (cuda_malloc, folder_paths, execution, server, latent_preview):
|
||||
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
|
||||
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
|
||||
# the way community custom nodes is pretty radioactive
|
||||
from ..cmd import cuda_malloc, folder_paths, latent_preview
|
||||
for module in (cuda_malloc, folder_paths, latent_preview):
|
||||
module_short_name = module.__name__.split(".")[-1]
|
||||
sys.modules[module_short_name] = module
|
||||
sys.modules['nodes'] = base_nodes
|
||||
from ..cmd import execution, server
|
||||
for module in (execution, server):
|
||||
module_short_name = module.__name__.split(".")[-1]
|
||||
sys.modules[module_short_name] = module
|
||||
|
||||
# Impact Pack wants to find model_patcher
|
||||
from .. import model_patcher
|
||||
sys.modules['model_patcher'] = model_patcher
|
||||
|
||||
comfy_extras_mitigation: Dict[str, types.ModuleType] = {}
|
||||
|
||||
import comfy_extras
|
||||
for module_name, module in sys.modules.items():
|
||||
if not module_name.startswith("comfy_extras.nodes"):
|
||||
continue
|
||||
module_short_name = module_name.split(".")[-1]
|
||||
setattr(comfy_extras, module_short_name, module)
|
||||
comfy_extras_mitigation[f'comfy_extras.{module_short_name}'] = module
|
||||
sys.modules.update(comfy_extras_mitigation)
|
||||
vanilla_custom_nodes = _vanilla_load_custom_nodes_2()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user