mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-08 13:20:50 +08:00
Improvements to compatibility with custom nodes, distributed backends, and other changes

- remove uv.lock since it will not be used in most cases for installation
- add CLI args to prevent some custom nodes from installing packages at runtime
- temp directories can now be shared between workers without being deleted prematurely
- exclude the yanked propcache release (0.4.0) in the dependencies
- fix configuration argument loading in some tests
parent d9e3ba4bec
commit 98ae55b059
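For workers the new flag defaults to on (see the worker entrypoint changes below); a frontend can also set it programmatically. A minimal sketch, assuming only the default_configuration() and _start_comfyui() entrypoints touched by this commit; this is not the only way to launch the server:

    import asyncio

    from comfy.cli_args import default_configuration
    from comfy.cmd.main import _start_comfyui

    configuration = default_configuration()
    # forbid custom nodes from shelling out to pip/uv at runtime (experimental)
    configuration.block_runtime_package_installation = True
    asyncio.run(_start_comfyui(configuration=configuration))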
@@ -31,10 +31,12 @@ class LogInterceptor(io.TextIOWrapper):
        if isinstance(data, str) and data.startswith("\r") and len(logs) > 0 and not logs[-1]["m"].endswith("\n"):
            logs.pop()
        logs.append(entry)
-       super().write(data)
+       if not self.closed:
+           super().write(data)

    def flush(self):
-       super().flush()
+       if not self.closed:
+           super().flush()
        for cb in self._flush_callbacks:
            cb(self._logs_since_flush)
        self._logs_since_flush = []

@@ -56,7 +58,8 @@ def on_flush(callback):

class StackTraceLogger(logging.Logger):
    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
-       if level >= logging.ERROR:
+       if not stack_info and level >= logging.ERROR and exc_info is None:
+           # create a stack even when there is no exception
            stack_info = True
        super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)
@@ -273,6 +273,12 @@ def _create_parser() -> EnhancedConfigArgParser:
        help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
    )

+   parser.add_argument(
+       "--block-runtime-package-installation",
+       action="store_true",
+       help="When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental)."
+   )
+
    default_db_url = db_config()
    parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
    parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
@@ -168,6 +168,7 @@ class Configuration(dict):
        blacklist_custom_nodes (list[str]): Specify custom node folders to never load. Accepts shell-style globs.
        whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
        default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
+       block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental).
    """

    def __init__(self, **kwargs):

@@ -286,6 +287,7 @@ class Configuration(dict):
        self.comfy_api_base: str = "https://api.comfy.org"
        self.database_url: str = db_config()
        self.default_device: Optional[int] = None
+       self.block_runtime_package_installation = None

        for key, value in kwargs.items():
            self[key] = value

@@ -315,6 +317,8 @@ class Configuration(dict):
        super().update(__m, **kwargs)
        for k, v in changes.items():
            self._notify_observers(k, v)
+       # make this more pythonic
+       return self

    def register_observer(self, observer: ConfigObserver):
        self._observers.append(observer)
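Since update() now returns the instance, a configuration can be built fluently. A small sketch assuming the Configuration type shown above; the key used here is the new flag from this commit:

    from comfy.cli_args_types import Configuration

    # update() returns self, so construction and overrides can be chained
    configuration = Configuration().update({"block_runtime_package_installation": True})
    assert configuration["block_runtime_package_installation"] is True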
@@ -19,19 +19,23 @@ from typing import List, Optional, Tuple, Literal
# order matters
from .main_pre import tracer
import torch
from frozendict import frozendict
from comfy_execution.graph_types import FrozenTopologicalSort, Input
from opentelemetry.trace import get_current_span, StatusCode, Status

from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, \
    make_locked_method_func
from comfy_api.latest import io
from comfy_compatibility.vanilla import vanilla_environment_node_execution_hooks
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
    DependencyAwareCache, \
    BasicCache
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_types import FrozenTopologicalSort
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, \
    WebUIProgressHandler, \
    ProgressRegistry
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
from comfy_execution.validation import validate_node_input
from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue

@@ -44,13 +48,11 @@ from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
    ExecutionStatusAsDict
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
from ..execution_ext import should_panic_on_exception
from ..node_requests_caching import use_requests_caching
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
from ..nodes_context import get_nodes, vanilla_node_execution_environment
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
    ProgressRegistry
from comfy_execution.validation import validate_node_input
from ..nodes_context import get_nodes

_module_properties = create_module_properties()
logger = logging.getLogger(__name__)
@@ -474,9 +476,11 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
    :param pending_subgraph_results:
    :return:
    """
-   with (context_execute_node(node_id),
-         vanilla_node_execution_environment(),
-         use_requests_caching()):
+   with (
+           context_execute_node(node_id),
+           vanilla_environment_node_execution_hooks(),
+           use_requests_caching(),
+   ):
        return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
@@ -1,7 +1,7 @@
import asyncio
import contextvars
import gc
import itertools

import logging
import os
import shutil

@@ -11,8 +11,10 @@ from pathlib import Path
from typing import Optional

# main_pre must be the earliest import
from .main_pre import args

from .main_pre import tracer
from ..cli_args_types import Configuration
from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
from ..execution_context import current_execution_context
from . import hook_breaker_ac10a0
from .extra_model_paths import load_extra_path_config
from .. import model_management

@@ -51,6 +53,7 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
    from ..cmd import execution
    from ..component_model import queue_types
    from .. import model_management
+   args = current_execution_context().configuration
    cache_type = execution.CacheType.CLASSIC
    if args.cache_lru > 0:
        cache_type = execution.CacheType.LRU

@@ -147,10 +150,14 @@ def setup_database():
    init_db()


-async def _start_comfyui(from_script_dir: Optional[Path] = None):
+async def _start_comfyui(from_script_dir: Optional[Path] = None, configuration: Optional[Configuration] = None):
    from ..execution_context import context_configuration
    from ..cli_args import cli_args_configuration
-   with context_configuration(cli_args_configuration()):
+   configuration = configuration or cli_args_configuration()
+   with (
+           context_configuration(configuration),
+           fc_cleanup_temp()
+   ):
        await __start_comfyui(from_script_dir=from_script_dir)

@@ -159,6 +166,7 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
    Runs ComfyUI's frontend and backend like upstream.
    :param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
    """
+   args = current_execution_context().configuration
    if not from_script_dir:
        os_getcwd = os.getcwd()
    else:

@@ -168,7 +176,6 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
    temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
    logger.debug(f"Setting temp directory to: {temp_dir}")
    folder_paths.set_temp_directory(temp_dir)
-   cleanup_temp()

    if args.user_directory:
        user_dir = os.path.abspath(args.user_directory)

@@ -305,7 +312,6 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
    finally:
        if distributed:
            await q.close()
-       cleanup_temp()


def entrypoint():
comfy/component_model/file_counter.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import shutil
from contextlib import contextmanager
from pathlib import Path

import filelock


class ContextWrapper:
    """A wrapper to hold context manager values for entry and exit."""

    def __init__(self, value):
        self.value = value
        self.ctr = None

    def __int__(self):
        return self.value


class FileCounter:
    def __init__(self, path):
        self.path = Path(path)

    async def __aenter__(self):
        wrapper = ContextWrapper(self.get_and_increment())
        self._context_wrapper = wrapper
        return wrapper

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._context_wrapper.ctr = self.decrement_and_get()

    def __enter__(self):
        """Increment on entering the context and return a wrapper."""
        wrapper = ContextWrapper(self.get_and_increment())
        self._context_wrapper = wrapper
        return wrapper

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Decrement on exiting the context and update the wrapper."""
        self._context_wrapper.ctr = self.decrement_and_get()

    def _read_and_write(self, operation):
        lock = filelock.FileLock(f"{self.path}.lock")
        with lock:
            count = 0
            try:
                with open(self.path, 'r') as f:
                    content = f.read().strip()
                    if content:
                        count = int(content)
            except FileNotFoundError:
                # File doesn't exist, will be created with initial value.
                pass
            except ValueError:
                # File is corrupt or empty, treat as 0 and overwrite.
                pass

            original_count = count
            new_count = operation(count)

            self.path.parent.mkdir(parents=True, exist_ok=True)
            with open(self.path, 'w') as f:
                f.write(str(new_count))

            return original_count, new_count

    def get_and_increment(self):
        """Atomically reads the current value, increments it, and returns the original value."""
        original_count, _ = self._read_and_write(lambda x: x + 1)
        return original_count

    def decrement_and_get(self):
        """Atomically decrements the value and returns the new value."""
        _, new_count = self._read_and_write(lambda x: x - 1)
        return new_count


@contextmanager
def cleanup_temp():
    from ..cli_args import args
    from ..cmd import folder_paths
    tmp_dir = Path(args.temp_directory or folder_paths.get_temp_directory())
    counter_path = tmp_dir / "counter.txt"
    fc_i = -1
    try:
        with FileCounter(counter_path) as fc:
            yield
        fc_i = fc.ctr
    finally:
        if fc_i == 0 and tmp_dir.is_dir():
            shutil.rmtree(tmp_dir, ignore_errors=True)
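A hedged sketch of the sharing pattern this enables between workers; the shared path below is illustrative, and cleanup_temp() above wraps this same pattern around the configured temp directory:

    from comfy.component_model.file_counter import FileCounter

    # "/shared/temp/counter.txt" is an illustrative path on a filesystem every worker can see
    counter = FileCounter("/shared/temp/counter.txt")
    with counter as entry:
        # entry.value (also int(entry)) is the count before this worker incremented it
        print(int(entry))
    # __exit__ decremented the on-disk counter; entry.ctr holds the remaining count
    if entry.ctr == 0:
        print("last worker out; safe to delete the shared temp directory")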
@@ -1,6 +1,7 @@
import asyncio

from ..cmd.main_pre import args
+from ..component_model.file_counter import cleanup_temp
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args


@@ -10,19 +11,27 @@ async def main():

    args.distributed_queue_worker = True
    args.distributed_queue_frontend = False

+   # in workers, there is a different default
+   if args.block_runtime_package_installation is None:
+       args.block_runtime_package_installation = True
+
    assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"

    configure_application_paths(args)
    executor = await executor_from_args(args)

-   async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
-                                      queue_name=args.distributed_queue_name,
-                                      executor=executor):
-       stop = asyncio.Event()
-       try:
-           await stop.wait()
-       except asyncio.CancelledError:
-           pass
+   async with (
+       DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
+                               queue_name=args.distributed_queue_name,
+                               executor=executor),
+   ):
+       with cleanup_temp():
+           stop = asyncio.Event()
+           try:
+               await stop.wait()
+           except (asyncio.CancelledError, InterruptedError, KeyboardInterrupt):
+               pass


def entrypoint():
@@ -8,18 +8,19 @@ import os
import sys
import time
import types
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
from typing import Iterable, Any, Generator
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock

-from comfy_compatibility.vanilla import prepare_vanilla_environment
+from comfy_compatibility.vanilla import prepare_vanilla_environment, patch_pip_install_subprocess_run, patch_pip_install_popen
from . import base_nodes
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
from .package_typing import ExportedNodes
from ..cmd import folder_paths
from ..component_model.plugins import prompt_server_instance_routes
from ..distributed.server_stub import ServerStub
from ..execution_context import current_execution_context

logger = logging.getLogger(__name__)

@@ -44,6 +45,10 @@ class StreamToLogger:
        # The logger handles its own flushing, so this can be a no-op.
        pass

+   @property
+   def encoding(self):
+       return "utf-8"
+

class _PromptServerStub(ServerStub):
    def __init__(self):

@@ -140,6 +145,7 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
            "comfyui-manager",
            "comfyui_ryanonyheinside",
            "comfyui-easy-use",
            "comfyui_custom_nodes_alekpet",
    ):
        from ..cmd import folder_paths
        old_file = folder_paths.__file__

@@ -147,9 +153,15 @@
    try:
        # mitigate path
        new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file))
+       config = current_execution_context()

-       with patch.object(folder_paths, "__file__", new_path), \
-               patch.object(sys.modules['nodes'], "EXTENSION_WEB_DIRS", {}, create=True):  # mitigate JS copy
+       block_installation = config and config.configuration and config.configuration.block_runtime_package_installation
+       with (
+           patch.object(folder_paths, "__file__", new_path),
+           # mitigate packages installing things dynamically
+           patch_pip_install_subprocess_run() if block_installation else nullcontext(),
+           patch_pip_install_popen() if block_installation else nullcontext(),
+       ):
            yield ExportedNodes()
    finally:
        # todo: mitigate "/manager/reboot"

@@ -263,6 +275,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
    # this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
    # found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
    # the way community custom nodes is pretty radioactive
+   # there's a lot of subtle details here, and unfortunately, once this is called, there are some things that have
+   # to be activated later, in different places, to make all the hacks necessary for custom nodes to work
    prepare_vanilla_environment()

    from ..cmd import folder_paths
@@ -1,9 +1,5 @@
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
import collections.abc
import sys
import threading
from contextlib import contextmanager
from unittest.mock import patch

import lazy_object_proxy

@@ -28,29 +24,3 @@ def get_nodes() -> ExportedNodes:
    if len(current_ctx.custom_nodes) == 0:
        return nodes
    return exported_nodes_view(nodes, current_ctx.custom_nodes)
-
-
-class _NodeClassMappingsShim(collections.abc.Mapping):
-    def __iter__(self):
-        for key in get_nodes().NODE_CLASS_MAPPINGS:
-            yield key
-
-    def __getitem__(self, item):
-        return get_nodes().NODE_CLASS_MAPPINGS[item]
-
-    def __len__(self):
-        return len(get_nodes().NODE_CLASS_MAPPINGS)
-
-    # todo: does this need to be mutable?
-
-
-@contextmanager
-def vanilla_node_execution_environment():
-    # check if we're running with patched nodes
-    if 'nodes' in sys.modules:
-        # this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
-        # easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
-        with patch('nodes.NODE_CLASS_MAPPINGS', _NodeClassMappingsShim()):
-            yield
-    else:
-        yield
@@ -1,16 +1,103 @@
from __future__ import annotations

import collections.abc
import contextvars
import logging
import subprocess
import sys
import types
from contextlib import contextmanager, nullcontext
from functools import partial
from importlib.util import find_spec
from pathlib import Path
from threading import RLock
from typing import Dict

import wrapt

logger = logging.getLogger(__name__)

# there isn't a way to do this per-thread, it's only per process, so the global is valid
# we don't want some kind of multiprocessing lock, because this is munging the sys.modules
# wrapt.synchronized will be used to synchronize this
_in_environment = False


class _NodeClassMappingsShim(collections.abc.Mapping):
    def __init__(self):
        super().__init__()
        self._active = 0
        self._active_lock = RLock()

    def activate(self):
        with self._active_lock:
            self._active += 1

    def deactivate(self):
        with self._active_lock:
            self._active -= 1

    def __iter__(self):
        if self._active > 0:
            from comfy.nodes_context import get_nodes
            for key in get_nodes().NODE_CLASS_MAPPINGS:
                yield key
        else:
            from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
            for key in NODE_CLASS_MAPPINGS:
                yield key

    def __getitem__(self, item):
        if self._active > 0:
            from comfy.nodes_context import get_nodes
            return get_nodes().NODE_CLASS_MAPPINGS[item]
        else:
            from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
            return NODE_CLASS_MAPPINGS[item]

    def __len__(self):
        if self._active > 0:
            from comfy.nodes_context import get_nodes
            return len(get_nodes().NODE_CLASS_MAPPINGS)
        else:
            from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
            return len(NODE_CLASS_MAPPINGS)

    # todo: does this need to be mutable?


class _NodeShim:
    def __init__(self):
        self.__name__ = 'nodes'
        self.__package__ = ''

        nodes_file = None
        try:
            # the 'nodes' module is expected to be in the directory above 'comfy'
            spec = find_spec('comfy')
            if spec and spec.origin:
                comfy_package_path = Path(spec.origin).parent
                nodes_module_dir = comfy_package_path.parent
                nodes_file = str(nodes_module_dir / 'nodes.py')
        except (ImportError, AttributeError):
            # don't do anything exotic
            pass

        self.__file__ = nodes_file
        self.__loader__ = None
        self.__spec__ = None

    def __node_class_mappings(self) -> _NodeClassMappingsShim:
        return getattr(self, "NODE_CLASS_MAPPINGS")

    def activate(self):
        self.__node_class_mappings().activate()

    def deactivate(self):
        self.__node_class_mappings().deactivate()


@wrapt.synchronized
def prepare_vanilla_environment():
    global _in_environment
    if _in_environment:

@@ -38,7 +125,22 @@ def prepare_vanilla_environment():
    for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
        module_short_name = module.__name__.split(".")[-1]
        sys.modules[module_short_name] = module
-   sys.modules['nodes'] = base_nodes

+   # easy-use needs a shim
+   # this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
+   # easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
+   # the shim must be activated after importing, which happens in a tightly coupled way
+   # todo: it's not clear if we should skip the dunder methods or not
+   nodes_shim_dir = {k: getattr(base_nodes, k) for k in dir(base_nodes) if not k.startswith("__")}
+   nodes_shim_dir['NODE_CLASS_MAPPINGS'] = _NodeClassMappingsShim()
+   nodes_shim_dir['EXTENSION_WEB_DIRS'] = {}
+
+   nodes_shim = _NodeShim()
+   for k, v in nodes_shim_dir.items():
+       setattr(nodes_shim, k, v)
+
+   sys.modules['nodes'] = nodes_shim

    comfyui_version = types.ModuleType('comfyui_version', '')
    setattr(comfyui_version, "__version__", __version__)
    sys.modules['comfyui_version'] = comfyui_version
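A hedged sketch of what the shim looks like to a legacy custom node once prepare_vanilla_environment() has run; it assumes a full ComfyUI install on the import path and is only meant to illustrate the lazy mapping described in the comments above:

    import sys
    from comfy_compatibility.vanilla import prepare_vanilla_environment

    prepare_vanilla_environment()   # installs the _NodeShim as sys.modules['nodes']

    import nodes                    # resolves to the shim, not a real nodes.py at the repo root
    # while the shim is inactive, lookups fall back to comfy.nodes.base_nodes;
    # during execution, vanilla_environment_node_execution_hooks() activates it so the
    # same mapping reflects every node loaded in the current execution context
    print(len(nodes.NODE_CLASS_MAPPINGS))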
@@ -76,3 +178,97 @@ def prepare_vanilla_environment():
        threading.Thread.start = patched_start
        setattr(threading.Thread.start, '__is_patched_by_us', True)
        logger.debug("Patched `threading.Thread.start` to propagate contextvars.")


@contextmanager
def patch_pip_install_subprocess_run():
    from unittest.mock import patch, MagicMock
    original_subprocess_run = subprocess.run

    def custom_side_effect(*args, **kwargs):
        command_list = args[0] if args else []

        # from easy-use
        is_pip_install_call = (
                isinstance(command_list, list) and
                len(command_list) == 6 and
                command_list[0] == sys.executable and
                command_list[1] == '-s' and
                command_list[2] == '-m' and
                command_list[3] == 'pip' and
                command_list[4] == 'install' and
                isinstance(command_list[5], str)
        )

        if is_pip_install_call:
            package_name = command_list[5]
            logger.info(f"Intercepted and mocked `pip install` for: {package_name}")
            mock_result = MagicMock()
            mock_result.returncode = 0
            return mock_result
        else:
            return original_subprocess_run(*args, **kwargs)

    with patch('subprocess.run') as mock_run:
        mock_run.side_effect = custom_side_effect
        yield


@contextmanager
def patch_pip_install_popen():
    from unittest.mock import patch, MagicMock
    original_subprocess_popen = subprocess.Popen

    def custom_side_effect(*args, **kwargs):
        command_list = args[0] if args else []

        is_pip_install_call = (
                isinstance(command_list, list) and
                len(command_list) >= 5 and
                command_list[0] == sys.executable and
                command_list[1] == "-m" and
                command_list[2] == "pip" and
                command_list[3] == "install" and
                # special case nunchaku
                "nunchaku" not in command_list[4:]
        )

        if is_pip_install_call:
            package_names = command_list[4:]
            logger.info(f"Intercepted and mocked `subprocess.Popen` for: pip install {' '.join(package_names)}")

            mock_popen_instance = MagicMock()
            # make stdout and stderr empty iterables so loops over them complete immediately.
            mock_popen_instance.stdout = []
            mock_popen_instance.stderr = []

            return mock_popen_instance
        else:
            return original_subprocess_popen(*args, **kwargs)

    with patch('subprocess.Popen') as mock_popen:
        mock_popen.side_effect = custom_side_effect
        yield mock_popen


@contextmanager
def vanilla_environment_node_execution_hooks():
    # this handles activating the NODE_CLASS_MAPPINGS shim
    from comfy.execution_context import current_execution_context
    ctx = current_execution_context()

    if 'nodes' in sys.modules and isinstance(sys.modules['nodes'], _NodeShim):
        nodes_shim: _NodeShim = sys.modules['nodes']
        try:
            nodes_shim.activate()

            block_installs = ctx and ctx.configuration and ctx.configuration.block_runtime_package_installation is True
            with (
                    patch_pip_install_subprocess_run() if block_installs else nullcontext(),
                    patch_pip_install_popen() if block_installs else nullcontext(),
            ):
                yield
        finally:
            nodes_shim.deactivate()
    else:
        yield
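A hedged sketch of the effect of these patches on a typical custom-node install call; 'some-package' is illustrative, and nothing is actually installed while the patch is active:

    import subprocess
    import sys

    from comfy_compatibility.vanilla import patch_pip_install_subprocess_run

    with patch_pip_install_subprocess_run():
        # matches the intercepted shape: [python, '-s', '-m', 'pip', 'install', '<package>']
        result = subprocess.run([sys.executable, '-s', '-m', 'pip', 'install', 'some-package'])
        # the call was mocked and reports success without touching the environment
        assert result.returncode == 0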
@@ -27,7 +27,7 @@ dependencies = [
    "torchsde>=0.2.6",
    "einops>=0.6.0",
    "open-clip-torch>=2.24.0",
-   "transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2",
+   "transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2,!=4.57.0",
    "tokenizers>=0.13.3",
    "sentencepiece",
    "peft>=0.10.0",

@@ -52,7 +52,7 @@ dependencies = [
    "tqdm",
    "protobuf>=3.20.0,<5.0.0",
    "psutil",
-   "ConfigArgParse",
+   "ConfigArgParse>=1.7.1",
    "aio-pika",
    "pyjwt[crypto]",
    "kornia>=0.7.0",

@@ -113,6 +113,8 @@ dependencies = [
    "stringzilla<4.2.0",
    "requests_cache",
    "universal_pathlib",
+   # yanked propcache is omitted
+   "propcache!=0.4.0",
]

[build-system]
@@ -13,6 +13,9 @@ from typing import List, Any, Generator
import pytest
import requests

+from comfy.cli_args import default_configuration
+from comfy.execution_context import context_configuration
+
os.environ['OTEL_METRICS_EXPORTER'] = 'none'
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"

@@ -27,11 +30,8 @@ logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)

def run_server(server_arguments: Configuration):
    from comfy.cmd.main import _start_comfyui
-   from comfy.cli_args import args
    import asyncio
-   for arg, value in server_arguments.items():
-       args[arg] = value
-   asyncio.run(_start_comfyui())
+   asyncio.run(_start_comfyui(configuration=server_arguments))


@pytest.fixture(scope="function", autouse=False)

@@ -140,7 +140,7 @@ def comfy_background_server(tmp_path_factory) -> Generator[tuple[Configuration,
    tmp_path = tmp_path_factory.mktemp("comfy_background_server")
    # Start server

-   configuration = Configuration()
+   configuration = default_configuration()
    configuration.listen = "localhost"
    configuration.output_directory = str(tmp_path)
    configuration.input_directory = str(tmp_path)
tests/distributed/test_counter.py (new file, 288 lines)
@@ -0,0 +1,288 @@
import os
import shutil
import subprocess
import sys
from threading import Thread, Barrier
from pathlib import Path
import asyncio
import contextvars
from concurrent.futures import ThreadPoolExecutor
import pytest
from testcontainers.core.container import DockerContainer
from testcontainers.core.wait_strategies import LogMessageWaitStrategy

from comfy.component_model.file_counter import FileCounter
from comfy.component_model.folder_path_types import FolderNames
from comfy.execution_context import context_folder_names_and_paths
from comfy.cmd.folder_paths import init_default_paths


def is_tool(name):
    """Check whether `name` is on PATH and marked as executable."""
    return shutil.which(name) is not None


def run_command(command, check=True):
    """Helper to run a shell command."""
    try:
        return subprocess.run(command, shell=True, check=check, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Command failed: {command}")
        print(f"--- STDOUT ---\n{e.stdout}")
        print(f"--- STDERR ---\n{e.stderr}")
        print("--------------")
        raise


@pytest.fixture(
    params=[
        pytest.param("local", id="local_filesystem"),
        pytest.param(
            "nfs",
            id="nfs_share",
            marks=pytest.mark.skipif(
                not sys.platform.startswith("linux")
                or not is_tool("mount.nfs")
                or not is_tool("sudo")
                or not os.path.exists("/sys/module/nfsd"),
                reason="NFS tests require sudo, nfs-common, and the 'nfsd' kernel module to be loaded.",
            ),
        ),
        pytest.param(
            "samba",
            id="samba_share",
            marks=pytest.mark.skipif(
                not sys.platform.startswith("linux")
                or not is_tool("mount.cifs")
                or not is_tool("sudo"),
                reason="Samba tests require sudo on Linux with cifs-utils installed.",
            ),
        ),
    ]
)
def counter_path_factory(request, tmp_path_factory):
    """A parameterized fixture to provide paths on local, NFS, and Samba filesystems."""
    if request.param == "local":
        yield lambda name: str(tmp_path_factory.mktemp("local_test") / name)
        return

    mount_point = tmp_path_factory.mktemp(f"mount_point_{request.param}")

    if request.param == "nfs":
        # 1. Create the host directory that will be mounted into the container.
        nfs_source = tmp_path_factory.mktemp("nfs_source")
        # 2. FIX: Set permissions on the *host* directory.
        os.chmod(str(nfs_source), 0o777)

        # 3. FIX: Use the new container's required path: /mnt/data
        container_path = "/mnt/data"

        # 4. FIX: Change to the new container image and configuration
        nfs_container = DockerContainer("ghcr.io/normal-computing/nfs-server:latest").with_env(
            "NFS_SERVER_ALLOWED_CLIENTS", "*"
        ).with_kwargs(privileged=True).with_exposed_ports(2049).with_volume_mapping(
            str(nfs_source), container_path, mode="rw"  # Mount to /mnt/data
        )

        nfs_container.start()

        # 5. FIX: Wait for the new container's export log message
        # (and remove the timeout as requested)
        nfs_container.waiting_for(
            LogMessageWaitStrategy(r"exporting /mnt/data")
        )

        request.addfinalizer(lambda: nfs_container.stop())

        ip_address = nfs_container.get_container_host_ip()
        nfs_port = nfs_container.get_exposed_port(2049)
        try:
            # 6. FIX: Mount using the new container's command format
            # (mounts root ":" and uses "-t nfs4")
            run_command(f"sleep 1 && sudo mount -t nfs4 -o proto=tcp,port={nfs_port} {ip_address}:/ {mount_point}")
            yield lambda name: str(mount_point / name)
        finally:
            run_command(f"sudo umount {mount_point}", check=False)

    elif request.param == "samba":
        # 1. Create the host directory.
        samba_source = tmp_path_factory.mktemp("samba_source")
        # 2. Set permissions on the *host* directory.
        os.chmod(str(samba_source), 0o777)

        share_name = "storage"

        # 3. FIX: Add the NAME environment variable to tell the container
        # to create a share named "storage".
        samba_container = DockerContainer("dockurr/samba:latest").with_env(
            "RW", "yes"
        ).with_env(
            "NAME", share_name  # <-- This is the crucial fix
        ).with_exposed_ports(445).with_volume_mapping(
            str(samba_source), "/storage", mode="rw"  # This maps the host dir to the internal /storage path
        )

        samba_container.start()

        # 4. Wait for the correct log message
        # (and remove the timeout as requested)
        samba_container.waiting_for(
            LogMessageWaitStrategy(r"smbd version .* started")
        )

        request.addfinalizer(lambda: samba_container.stop())

        ip_address = samba_container.get_container_host_ip()
        samba_port = samba_container.get_exposed_port(445)
        try:
            # 5. FIX: Mount with the default username/password, not as guest.
            run_command(f"sleep 1 && sudo mount -t cifs -o username=samba,password=secret,vers=3.0,port={samba_port},uid=$(id -u),gid=$(id -g) //{ip_address}/{share_name} {mount_point}", check=True)
            yield lambda name: str(mount_point / name)
        finally:
            run_command(f"sudo umount {mount_point}", check=False)


def test_initial_state(counter_path_factory):
    """Verify initial state and file creation."""
    counter_file = counter_path_factory("counter.txt")
    lock_file = counter_path_factory("counter.txt.lock")

    assert not os.path.exists(counter_file)
    assert not os.path.exists(lock_file)

    counter = FileCounter(str(counter_file))
    assert counter.get_and_increment() == 0
    assert os.path.exists(counter_file)
    with open(counter_file, "r") as f:
        assert f.read() == "1"


def test_mkdirs(counter_path_factory):
    counter_file = counter_path_factory("new_dir/counter.txt")
    assert not os.path.exists(os.path.dirname(counter_file))

    counter = FileCounter(str(counter_file))
    assert counter.get_and_increment() == 0
    assert counter.get_and_increment() == 1


def test_increment_and_decrement(counter_path_factory):
    """Test the increment and decrement logic."""
    counter_file = counter_path_factory("counter.txt")
    counter = FileCounter(str(counter_file))

    assert counter.get_and_increment() == 0  # val: 0, new_val: 1
    assert counter.get_and_increment() == 1  # val: 1, new_val: 2
    assert counter.get_and_increment() == 2  # val: 2, new_val: 3

    assert counter.decrement_and_get() == 2  # val: 3, new_val: 2
    assert counter.decrement_and_get() == 1  # val: 2, new_val: 1

    assert counter.get_and_increment() == 1  # val: 1, new_val: 2


def test_multiple_instances_same_path(counter_path_factory):
    """Verify that multiple FileCounter instances on the same path work correctly."""
    counter_file = counter_path_factory("counter.txt")
    counter1 = FileCounter(str(counter_file))
    counter2 = FileCounter(str(counter_file))

    assert counter1.get_and_increment() == 0
    assert counter2.get_and_increment() == 1
    assert counter1.decrement_and_get() == 1
    assert counter2.get_and_increment() == 1

    with open(counter_file, "r") as f:
        assert f.read() == "2"


def test_multithreaded_access(counter_path_factory):
    """Ensure atomicity with multiple threads."""
    counter_file = counter_path_factory("counter.txt")
    counter = FileCounter(str(counter_file))
    num_threads = 10
    increments_per_thread = 100

    def worker():
        for _ in range(increments_per_thread):
            counter.get_and_increment()

    threads = [Thread(target=worker) for _ in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    with open(counter_file, "r") as f:
        final_value = int(f.read())
    assert final_value == num_threads * increments_per_thread


def test_context_manager(counter_path_factory):
    """Test that the counter can be used as a context manager."""
    counter_file = counter_path_factory("counter.txt")
    counter = FileCounter(str(counter_file))

    # Initial state should be 0
    assert counter.get_and_increment() == 0
    with open(counter_file) as f:
        assert f.read() == "1"

    with counter as wrapper:
        # The wrapper's value is the original count before increment.
        assert wrapper.value == 1
        # It can also be used as an integer directly.
        assert int(wrapper) == 1
        with open(counter_file) as f:
            assert f.read() == "2"  # Inside context, value is 2

    with open(counter_file) as f:
        assert f.read() == "1"  # Exited context, decremented back to 1
    # After exit, the wrapper's 'ctr' attribute holds the new value.
    assert wrapper.ctr == 1


async def test_cleanup_temp_multithreaded(tmp_path):
    """
    Test that cleanup_temp correctly deletes the temp directory only
    after the last thread has exited the context.
    """
    # 1. Use the application's context to define the temp directory for this test.
    # This is a cleaner approach than mocking.
    base_dir = tmp_path / "base"
    temp_dir_override = base_dir / "temp"
    fn = FolderNames(base_paths=[base_dir])
    init_default_paths(fn, base_paths_from_configuration=False)
    # Override the default temp path
    fn.temp_directory = temp_dir_override

    from comfy.component_model.file_counter import cleanup_temp

    num_threads = 5
    # Barrier to make threads wait for each other before exiting.
    barrier = Barrier(num_threads)
    loop = asyncio.get_running_loop()

    def worker():
        """The task for each thread. It enters the cleanup context and waits."""
        with cleanup_temp():
            # The temp directory and counter file should exist inside the context.
            assert temp_dir_override.exists()
            assert (temp_dir_override / "counter.txt").exists()
        # After exiting, the directory should still exist until the last thread is done.

    # Use the context manager to set the folder paths for the current async context.
    with context_folder_names_and_paths(fn):
        # Capture the current context, which now includes the folder_paths settings.

        # Run the worker function in a thread pool, applying the captured context to each task.
        with ThreadPoolExecutor(max_workers=num_threads) as executor:

            tasks = [loop.run_in_executor(executor, contextvars.copy_context().run, worker) for _ in range(num_threads)]
            # Wait for all threads to complete their work.
            await asyncio.gather(*tasks)

    # After all threads have joined, the counter should be 0, and the directory deleted.
    assert not temp_dir_override.exists()
    # the base dir is not going to be deleted
    assert base_dir.exists()
@@ -16,6 +16,7 @@ from typing import List, Dict, Any, Generator

from PIL import Image

+from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy_execution.graph_utils import GraphBuilder
from .test_execution import ComfyClient, RunResult

@@ -188,7 +189,7 @@ class TestProgressIsolation:
        tmp_path = tmp_path_factory.mktemp("comfy_background_server")
        # Start server

-       configuration = Configuration()
+       configuration = default_configuration()
        configuration.listen = args_pytest["listen"]
        configuration.port = args_pytest["port"]
        configuration.cpu = True

@@ -205,7 +206,7 @@ testing_pack:
        yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
        with open(yaml_path, mode="wt") as f:
            f.write(extra_nodes)
-       configuration.extra_model_paths_config = [str(yaml_path)]
+       configuration.extra_model_paths_config = [yaml_path]

        yield from comfy_background_server_from_config(configuration)
@@ -1,5 +1,6 @@
import importlib.resources
import json
+import logging
from importlib.abc import Traversable

import pytest

@@ -11,6 +12,8 @@ from comfy.model_downloader_types import CivitFile, HuggingFile
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
from . import workflows

+logger = logging.getLogger(__name__)
+

@pytest.fixture(scope="function", autouse=False)
async def client(tmp_path_factory) -> Comfy:

@@ -35,6 +38,7 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:

    prompt = Prompt.validate(workflow)
    # todo: add all the models we want to test a bit m2ore elegantly
+   outputs = {}
    try:
        outputs = await client.queue_prompt(prompt)
    except TorchAudioNotFoundError:

@@ -46,8 +50,14 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
    elif any(v.class_type == "SaveAudio" for v in prompt.values()):
        save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
        assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
    elif any(v.class_type == "SaveAnimatedWEBP" for v in prompt.values()):
        save_video_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAnimatedWEBP")
        assert outputs[save_video_node_id]["images"][0]["filename"] is not None
+   elif any(v.class_type == "PreviewString" for v in prompt.values()):
+       save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
+       output_str = outputs[save_image_node_id]["string"][0]
+       assert output_str is not None
+       assert len(output_str) > 0
    else:
        assert len(outputs) > 0
+       logger.warning(f"test {workflow_name} did not have a node that could be checked for output")