Improve vanilla node importing

This commit is contained in:
doctorpangloss 2024-03-08 16:29:48 -08:00
parent 54b94e4c87
commit 175a50d7ba
8 changed files with 76 additions and 48 deletions

View File

@ -147,10 +147,6 @@ On macOS, install exactly Python 3.11 using `brew`, which you can download from
```shell
comfyui
```
To generate python OpenAPI models:
```shell
comfyui-openapi-gen
```
To run tests:
```shell
pytest tests/inference

View File

@ -10,21 +10,18 @@ import traceback
import typing
from typing import List, Optional, Tuple, Union
from typing_extensions import TypedDict
# Suppress warnings during import
import warnings
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
import torch
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..component_model.executor_types import ExecutorToClientProgress
from .. import model_management
from ..nodes.package_typing import ExportedNodes
from ..nodes.package import import_all_nodes_in_workspace
nodes = import_all_nodes_in_workspace()
from .. import model_management # type: ignore
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
nodes: ExportedNodes = import_all_nodes_in_workspace()
def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
if extra_data is None:
@ -88,7 +85,7 @@ def map_node_over_list(obj, input_data_all, func, allow_interrupt=False):
results.append(getattr(obj, func)(**input_data_all))
elif max_len_input == 0:
if allow_interrupt:
nodes.before_node_execution()
model_management.throw_exception_if_processing_interrupted()
results.append(getattr(obj, func)())
else:
for i in range(max_len_input):
@ -141,7 +138,14 @@ def format_value(x):
return str(x)
def recursive_execute(server: ExecutorToClientProgress, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui,
def recursive_execute(server: ExecutorToClientProgress,
prompt,
outputs,
current_item,
extra_data,
executed,
prompt_id,
outputs_ui,
object_storage):
unique_id = current_item
inputs = prompt[unique_id]['inputs']

View File

@ -1,5 +1,8 @@
from .. import options
# Suppress warnings during import
import warnings
warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.")
options.enable_args_parsing()
import os
@ -9,6 +12,8 @@ from ..cmd import cuda_malloc
from ..cmd import folder_paths
from .extra_model_paths import load_extra_path_config
from ..analytics.analytics import initialize_event_tracking
from ..nodes.package import import_all_nodes_in_workspace
import time
@ -77,7 +82,6 @@ if args.deterministic:
from .. import utils
from ..cmd import execution
from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.queue_types import BinaryEventTypes, ExecutionStatus
@ -88,7 +92,9 @@ from ..distributed.server_stub import ServerStub
def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
e = execution.PromptExecutor(_server)
from ..cmd.execution import PromptExecutor
e = PromptExecutor(_server)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@ -202,10 +208,27 @@ async def main():
except:
pass
# configure extra model paths earlier
try:
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
except NameError:
pass
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
loop = asyncio.get_event_loop()
server = server_module.PromptServer(loop)
if args.external_address is not None:
server.external_address = args.external_address
# at this stage, it's safe to import nodes
server.nodes = import_all_nodes_in_workspace()
# as a side effect, this also populates the nodes for execution
if args.distributed_queue_connection_uri is not None:
distributed = True
q = DistributedPromptQueue(
@ -219,19 +242,10 @@ async def main():
await q.init()
else:
distributed = False
q = execution.PromptQueue(server)
from execution import PromptQueue
q = PromptQueue(server)
server.prompt_queue = q
try:
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
load_extra_path_config(extra_model_paths_config_path)
except NameError:
pass
if args.extra_model_paths_config:
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
server.add_routes()
hijack_progress(server)

View File

@ -34,10 +34,8 @@ from .. import model_management
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes
from ..vendor.appdirs import user_data_dir
nodes = import_all_nodes_in_workspace()
from ..app.user_manager import UserManager
@ -114,6 +112,7 @@ class PromptServer(ExecutorToClientProgress):
self.client_id = None
self.on_prompt_handlers = []
self.nodes: ExportedNodes = ExportedNodes()
@routes.get('/ws')
async def websocket_handler(request):
@ -157,7 +156,7 @@ class PromptServer(ExecutorToClientProgress):
glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
files = glob.glob(os.path.join(glob.escape(dir), '**/*.js'), recursive=True)
extensions.extend(list(map(lambda f: "/extensions/" + quote(
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
@ -398,7 +397,7 @@ class PromptServer(ExecutorToClientProgress):
return web.json_response(self.get_queue_info())
def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class]
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['output'] = obj_class.RETURN_TYPES
@ -407,8 +406,8 @@ class PromptServer(ExecutorToClientProgress):
obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[
node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['display_name'] = self.nodes.NODE_DISPLAY_NAME_MAPPINGS[
node_class] if node_class in self.nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class, 'DESCRIPTION') else ''
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
@ -423,7 +422,7 @@ class PromptServer(ExecutorToClientProgress):
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:
for x in self.nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
@ -436,7 +435,7 @@ class PromptServer(ExecutorToClientProgress):
async def get_object_info_node(request):
node_class = request.match_info.get("node_class", None)
out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
if (node_class is not None) and (node_class in self.nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class)
return web.json_response(out)
@ -710,7 +709,7 @@ class PromptServer(ExecutorToClientProgress):
self.user_manager.add_routes(self.routes)
self.app.add_routes(self.routes)
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([
web.static('/extensions/' + quote(name), dir),
])

View File

@ -16,7 +16,7 @@ from . import base_nodes
from .package_typing import ExportedNodes
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
_comfy_nodes = ExportedNodes()
_comfy_nodes: ExportedNodes = ExportedNodes()
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
@ -77,7 +77,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, print_import
def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
global _comfy_nodes
# only load these nodes once
if len(_comfy_nodes) == 0:
base_and_extra = reduce(lambda x, y: x.update(y),
map(_import_and_enumerate_nodes_in_module, [
@ -88,9 +88,6 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
ExportedNodes())
custom_nodes_mappings = ExportedNodes()
if vanilla_custom_nodes:
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
# load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
# Load the module associated with the current entry point
@ -100,6 +97,11 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True) -> ExportedNodes:
if isinstance(module, types.ModuleType):
custom_nodes_mappings.update(
_import_and_enumerate_nodes_in_module(module, print_import_times=True))
# load the vanilla custom nodes last
if vanilla_custom_nodes:
custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
# don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra

View File

@ -124,12 +124,12 @@ class ExportedNodes:
def __sub__(self, other: ExportedNodes):
exported_nodes = ExportedNodes().update(self)
for self_key in exported_nodes.NODE_CLASS_MAPPINGS:
for self_key in frozenset(exported_nodes.NODE_CLASS_MAPPINGS):
if self_key in other.NODE_CLASS_MAPPINGS:
exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
for self_key in exported_nodes.EXTENSION_WEB_DIRS:
for self_key in frozenset(exported_nodes.EXTENSION_WEB_DIRS):
if self_key in other.EXTENSION_WEB_DIRS:
exported_nodes.EXTENSION_WEB_DIRS.pop(self_key)
return exported_nodes

View File

@ -89,17 +89,31 @@ def _vanilla_load_custom_nodes_2() -> ExportedNodes:
def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
# only vanilla custom nodes will ever go into the custom_nodes directory
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be found
from ..cmd import cuda_malloc, folder_paths, execution, server, latent_preview
for module in (cuda_malloc, folder_paths, execution, server, latent_preview):
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
# the way community custom nodes is pretty radioactive
from ..cmd import cuda_malloc, folder_paths, latent_preview
for module in (cuda_malloc, folder_paths, latent_preview):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
sys.modules['nodes'] = base_nodes
from ..cmd import execution, server
for module in (execution, server):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
# Impact Pack wants to find model_patcher
from .. import model_patcher
sys.modules['model_patcher'] = model_patcher
comfy_extras_mitigation: Dict[str, types.ModuleType] = {}
import comfy_extras
for module_name, module in sys.modules.items():
if not module_name.startswith("comfy_extras.nodes"):
continue
module_short_name = module_name.split(".")[-1]
setattr(comfy_extras, module_short_name, module)
comfy_extras_mitigation[f'comfy_extras.{module_short_name}'] = module
sys.modules.update(comfy_extras_mitigation)
vanilla_custom_nodes = _vanilla_load_custom_nodes_2()

View File

@ -181,7 +181,6 @@ setup(
setup_requires=["pip", "wheel"],
entry_points={
'console_scripts': [
'comfyui-openapi-gen = comfy.cmd.openapi_gen:main',
'comfyui = comfy.cmd.main:entrypoint',
'comfyui-worker = comfy.cmd.worker:entrypoint'
],