mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 13:50:49 +08:00
Improvements to compatibility with custom nodes, distributed
backends and other changes - remove uv.lock since it will not be used in most cases for installation - add cli args to prevent some custom nodes from installing packages at runtime - temp directories can now be shared between workers without being deleted - propcache yanked is now in the dependencies - fix configuration arguments loading in some tests
This commit is contained in:
parent
d9e3ba4bec
commit
98ae55b059
@ -31,10 +31,12 @@ class LogInterceptor(io.TextIOWrapper):
|
|||||||
if isinstance(data, str) and data.startswith("\r") and len(logs) > 0 and not logs[-1]["m"].endswith("\n"):
|
if isinstance(data, str) and data.startswith("\r") and len(logs) > 0 and not logs[-1]["m"].endswith("\n"):
|
||||||
logs.pop()
|
logs.pop()
|
||||||
logs.append(entry)
|
logs.append(entry)
|
||||||
super().write(data)
|
if not self.closed:
|
||||||
|
super().write(data)
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
super().flush()
|
if not self.closed:
|
||||||
|
super().flush()
|
||||||
for cb in self._flush_callbacks:
|
for cb in self._flush_callbacks:
|
||||||
cb(self._logs_since_flush)
|
cb(self._logs_since_flush)
|
||||||
self._logs_since_flush = []
|
self._logs_since_flush = []
|
||||||
@ -56,7 +58,8 @@ def on_flush(callback):
|
|||||||
|
|
||||||
class StackTraceLogger(logging.Logger):
|
class StackTraceLogger(logging.Logger):
|
||||||
def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
|
def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
|
||||||
if level >= logging.ERROR:
|
if not stack_info and level >= logging.ERROR and exc_info is None:
|
||||||
|
# create a stack even when there is no exception
|
||||||
stack_info = True
|
stack_info = True
|
||||||
super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)
|
super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)
|
||||||
|
|
||||||
|
|||||||
@ -273,6 +273,12 @@ def _create_parser() -> EnhancedConfigArgParser:
|
|||||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--block-runtime-package-installation",
|
||||||
|
action="store_true",
|
||||||
|
help="When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental)."
|
||||||
|
)
|
||||||
|
|
||||||
default_db_url = db_config()
|
default_db_url = db_config()
|
||||||
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||||
parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
|
parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
|
||||||
|
|||||||
@ -168,6 +168,7 @@ class Configuration(dict):
|
|||||||
blacklist_custom_nodes (list[str]): Specify custom node folders to never load. Accepts shell-style globs.
|
blacklist_custom_nodes (list[str]): Specify custom node folders to never load. Accepts shell-style globs.
|
||||||
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
|
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
|
||||||
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
|
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
|
||||||
|
block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -286,6 +287,7 @@ class Configuration(dict):
|
|||||||
self.comfy_api_base: str = "https://api.comfy.org"
|
self.comfy_api_base: str = "https://api.comfy.org"
|
||||||
self.database_url: str = db_config()
|
self.database_url: str = db_config()
|
||||||
self.default_device: Optional[int] = None
|
self.default_device: Optional[int] = None
|
||||||
|
self.block_runtime_package_installation = None
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
self[key] = value
|
self[key] = value
|
||||||
@ -315,6 +317,8 @@ class Configuration(dict):
|
|||||||
super().update(__m, **kwargs)
|
super().update(__m, **kwargs)
|
||||||
for k, v in changes.items():
|
for k, v in changes.items():
|
||||||
self._notify_observers(k, v)
|
self._notify_observers(k, v)
|
||||||
|
# make this more pythonic
|
||||||
|
return self
|
||||||
|
|
||||||
def register_observer(self, observer: ConfigObserver):
|
def register_observer(self, observer: ConfigObserver):
|
||||||
self._observers.append(observer)
|
self._observers.append(observer)
|
||||||
|
|||||||
@ -19,19 +19,23 @@ from typing import List, Optional, Tuple, Literal
|
|||||||
# order matters
|
# order matters
|
||||||
from .main_pre import tracer
|
from .main_pre import tracer
|
||||||
import torch
|
import torch
|
||||||
from frozendict import frozendict
|
|
||||||
from comfy_execution.graph_types import FrozenTopologicalSort, Input
|
|
||||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||||
|
|
||||||
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, \
|
||||||
|
make_locked_method_func
|
||||||
|
from comfy_api.latest import io
|
||||||
|
from comfy_compatibility.vanilla import vanilla_environment_node_execution_hooks
|
||||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
|
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
|
||||||
DependencyAwareCache, \
|
DependencyAwareCache, \
|
||||||
BasicCache
|
BasicCache
|
||||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||||
|
from comfy_execution.graph_types import FrozenTopologicalSort
|
||||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||||
|
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, \
|
||||||
|
WebUIProgressHandler, \
|
||||||
|
ProgressRegistry
|
||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_execution.validation import validate_node_input
|
||||||
from comfy_api.latest import io
|
|
||||||
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
|
|
||||||
from .. import interruption
|
from .. import interruption
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||||
@ -44,13 +48,11 @@ from ..component_model.module_property import create_module_properties
|
|||||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
|
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
|
||||||
ExecutionStatusAsDict
|
ExecutionStatusAsDict
|
||||||
from ..execution_context import context_execute_node, context_execute_prompt
|
from ..execution_context import context_execute_node, context_execute_prompt
|
||||||
|
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
|
||||||
from ..execution_ext import should_panic_on_exception
|
from ..execution_ext import should_panic_on_exception
|
||||||
from ..node_requests_caching import use_requests_caching
|
from ..node_requests_caching import use_requests_caching
|
||||||
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||||
from ..nodes_context import get_nodes, vanilla_node_execution_environment
|
from ..nodes_context import get_nodes
|
||||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
|
|
||||||
ProgressRegistry
|
|
||||||
from comfy_execution.validation import validate_node_input
|
|
||||||
|
|
||||||
_module_properties = create_module_properties()
|
_module_properties = create_module_properties()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -474,9 +476,11 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
|
|||||||
:param pending_subgraph_results:
|
:param pending_subgraph_results:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with (context_execute_node(node_id),
|
with (
|
||||||
vanilla_node_execution_environment(),
|
context_execute_node(node_id),
|
||||||
use_requests_caching()):
|
vanilla_environment_node_execution_hooks(),
|
||||||
|
use_requests_caching(),
|
||||||
|
):
|
||||||
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextvars
|
import contextvars
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
@ -11,8 +11,10 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
# main_pre must be the earliest import
|
# main_pre must be the earliest import
|
||||||
from .main_pre import args
|
from .main_pre import tracer
|
||||||
|
from ..cli_args_types import Configuration
|
||||||
|
from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
|
||||||
|
from ..execution_context import current_execution_context
|
||||||
from . import hook_breaker_ac10a0
|
from . import hook_breaker_ac10a0
|
||||||
from .extra_model_paths import load_extra_path_config
|
from .extra_model_paths import load_extra_path_config
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
@ -51,6 +53,7 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
|
|||||||
from ..cmd import execution
|
from ..cmd import execution
|
||||||
from ..component_model import queue_types
|
from ..component_model import queue_types
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
|
args = current_execution_context().configuration
|
||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
if args.cache_lru > 0:
|
if args.cache_lru > 0:
|
||||||
cache_type = execution.CacheType.LRU
|
cache_type = execution.CacheType.LRU
|
||||||
@ -147,10 +150,14 @@ def setup_database():
|
|||||||
init_db()
|
init_db()
|
||||||
|
|
||||||
|
|
||||||
async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
async def _start_comfyui(from_script_dir: Optional[Path] = None, configuration: Optional[Configuration] = None):
|
||||||
from ..execution_context import context_configuration
|
from ..execution_context import context_configuration
|
||||||
from ..cli_args import cli_args_configuration
|
from ..cli_args import cli_args_configuration
|
||||||
with context_configuration(cli_args_configuration()):
|
configuration = configuration or cli_args_configuration()
|
||||||
|
with (
|
||||||
|
context_configuration(configuration),
|
||||||
|
fc_cleanup_temp()
|
||||||
|
):
|
||||||
await __start_comfyui(from_script_dir=from_script_dir)
|
await __start_comfyui(from_script_dir=from_script_dir)
|
||||||
|
|
||||||
|
|
||||||
@ -159,6 +166,7 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
|
|||||||
Runs ComfyUI's frontend and backend like upstream.
|
Runs ComfyUI's frontend and backend like upstream.
|
||||||
:param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
|
:param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
|
||||||
"""
|
"""
|
||||||
|
args = current_execution_context().configuration
|
||||||
if not from_script_dir:
|
if not from_script_dir:
|
||||||
os_getcwd = os.getcwd()
|
os_getcwd = os.getcwd()
|
||||||
else:
|
else:
|
||||||
@ -168,7 +176,6 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
|
|||||||
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
|
||||||
logger.debug(f"Setting temp directory to: {temp_dir}")
|
logger.debug(f"Setting temp directory to: {temp_dir}")
|
||||||
folder_paths.set_temp_directory(temp_dir)
|
folder_paths.set_temp_directory(temp_dir)
|
||||||
cleanup_temp()
|
|
||||||
|
|
||||||
if args.user_directory:
|
if args.user_directory:
|
||||||
user_dir = os.path.abspath(args.user_directory)
|
user_dir = os.path.abspath(args.user_directory)
|
||||||
@ -305,7 +312,6 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
|
|||||||
finally:
|
finally:
|
||||||
if distributed:
|
if distributed:
|
||||||
await q.close()
|
await q.close()
|
||||||
cleanup_temp()
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint():
|
def entrypoint():
|
||||||
|
|||||||
90
comfy/component_model/file_counter.py
Normal file
90
comfy/component_model/file_counter.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
import shutil
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import filelock
|
||||||
|
|
||||||
|
|
||||||
|
class ContextWrapper:
|
||||||
|
"""A wrapper to hold context manager values for entry and exit."""
|
||||||
|
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
self.ctr = None
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class FileCounter:
|
||||||
|
def __init__(self, path):
|
||||||
|
self.path = Path(path)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
wrapper = ContextWrapper(self.get_and_increment())
|
||||||
|
self._context_wrapper = wrapper
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self._context_wrapper.ctr = self.decrement_and_get()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Increment on entering the context and return a wrapper."""
|
||||||
|
wrapper = ContextWrapper(self.get_and_increment())
|
||||||
|
self._context_wrapper = wrapper
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Decrement on exiting the context and update the wrapper."""
|
||||||
|
self._context_wrapper.ctr = self.decrement_and_get()
|
||||||
|
|
||||||
|
def _read_and_write(self, operation):
|
||||||
|
lock = filelock.FileLock(f"{self.path}.lock")
|
||||||
|
with lock:
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
with open(self.path, 'r') as f:
|
||||||
|
content = f.read().strip()
|
||||||
|
if content:
|
||||||
|
count = int(content)
|
||||||
|
except FileNotFoundError:
|
||||||
|
# File doesn't exist, will be created with initial value.
|
||||||
|
pass
|
||||||
|
except ValueError:
|
||||||
|
# File is corrupt or empty, treat as 0 and overwrite.
|
||||||
|
pass
|
||||||
|
|
||||||
|
original_count = count
|
||||||
|
new_count = operation(count)
|
||||||
|
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(self.path, 'w') as f:
|
||||||
|
f.write(str(new_count))
|
||||||
|
|
||||||
|
return original_count, new_count
|
||||||
|
|
||||||
|
def get_and_increment(self):
|
||||||
|
"""Atomically reads the current value, increments it, and returns the original value."""
|
||||||
|
original_count, _ = self._read_and_write(lambda x: x + 1)
|
||||||
|
return original_count
|
||||||
|
|
||||||
|
def decrement_and_get(self):
|
||||||
|
"""Atomically decrements the value and returns the new value."""
|
||||||
|
_, new_count = self._read_and_write(lambda x: x - 1)
|
||||||
|
return new_count
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def cleanup_temp():
|
||||||
|
from ..cli_args import args
|
||||||
|
from ..cmd import folder_paths
|
||||||
|
tmp_dir = Path(args.temp_directory or folder_paths.get_temp_directory())
|
||||||
|
counter_path = tmp_dir / "counter.txt"
|
||||||
|
fc_i = -1
|
||||||
|
try:
|
||||||
|
with FileCounter(counter_path) as fc:
|
||||||
|
yield
|
||||||
|
fc_i = fc.ctr
|
||||||
|
finally:
|
||||||
|
if fc_i == 0 and tmp_dir.is_dir():
|
||||||
|
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from ..cmd.main_pre import args
|
from ..cmd.main_pre import args
|
||||||
|
from ..component_model.file_counter import cleanup_temp
|
||||||
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
|
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
|
||||||
|
|
||||||
|
|
||||||
@ -10,19 +11,27 @@ async def main():
|
|||||||
|
|
||||||
args.distributed_queue_worker = True
|
args.distributed_queue_worker = True
|
||||||
args.distributed_queue_frontend = False
|
args.distributed_queue_frontend = False
|
||||||
|
|
||||||
|
# in workers, there is a different default
|
||||||
|
if args.block_runtime_package_installation is None:
|
||||||
|
args.block_runtime_package_installation = True
|
||||||
|
|
||||||
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
|
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
|
||||||
|
|
||||||
configure_application_paths(args)
|
configure_application_paths(args)
|
||||||
executor = await executor_from_args(args)
|
executor = await executor_from_args(args)
|
||||||
|
|
||||||
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
|
async with (
|
||||||
queue_name=args.distributed_queue_name,
|
DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
|
||||||
executor=executor):
|
queue_name=args.distributed_queue_name,
|
||||||
stop = asyncio.Event()
|
executor=executor),
|
||||||
try:
|
):
|
||||||
await stop.wait()
|
with cleanup_temp():
|
||||||
except asyncio.CancelledError:
|
stop = asyncio.Event()
|
||||||
pass
|
try:
|
||||||
|
await stop.wait()
|
||||||
|
except (asyncio.CancelledError, InterruptedError, KeyboardInterrupt):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def entrypoint():
|
def entrypoint():
|
||||||
|
|||||||
@ -8,18 +8,19 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import types
|
import types
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, nullcontext
|
||||||
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
|
||||||
from typing import Iterable, Any, Generator
|
from typing import Iterable, Any, Generator
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
from comfy_compatibility.vanilla import prepare_vanilla_environment
|
from comfy_compatibility.vanilla import prepare_vanilla_environment, patch_pip_install_subprocess_run, patch_pip_install_popen
|
||||||
from . import base_nodes
|
from . import base_nodes
|
||||||
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
from ..cmd import folder_paths
|
from ..cmd import folder_paths
|
||||||
from ..component_model.plugins import prompt_server_instance_routes
|
from ..component_model.plugins import prompt_server_instance_routes
|
||||||
from ..distributed.server_stub import ServerStub
|
from ..distributed.server_stub import ServerStub
|
||||||
|
from ..execution_context import current_execution_context
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -44,6 +45,10 @@ class StreamToLogger:
|
|||||||
# The logger handles its own flushing, so this can be a no-op.
|
# The logger handles its own flushing, so this can be a no-op.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self):
|
||||||
|
return "utf-8"
|
||||||
|
|
||||||
|
|
||||||
class _PromptServerStub(ServerStub):
|
class _PromptServerStub(ServerStub):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -140,6 +145,7 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
|
|||||||
"comfyui-manager",
|
"comfyui-manager",
|
||||||
"comfyui_ryanonyheinside",
|
"comfyui_ryanonyheinside",
|
||||||
"comfyui-easy-use",
|
"comfyui-easy-use",
|
||||||
|
"comfyui_custom_nodes_alekpet",
|
||||||
):
|
):
|
||||||
from ..cmd import folder_paths
|
from ..cmd import folder_paths
|
||||||
old_file = folder_paths.__file__
|
old_file = folder_paths.__file__
|
||||||
@ -147,9 +153,15 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
|
|||||||
try:
|
try:
|
||||||
# mitigate path
|
# mitigate path
|
||||||
new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file))
|
new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file))
|
||||||
|
config = current_execution_context()
|
||||||
|
|
||||||
with patch.object(folder_paths, "__file__", new_path), \
|
block_installation = config and config.configuration and config.configuration.block_runtime_package_installation
|
||||||
patch.object(sys.modules['nodes'], "EXTENSION_WEB_DIRS", {}, create=True): # mitigate JS copy
|
with (
|
||||||
|
patch.object(folder_paths, "__file__", new_path),
|
||||||
|
# mitigate packages installing things dynamically
|
||||||
|
patch_pip_install_subprocess_run() if block_installation else nullcontext(),
|
||||||
|
patch_pip_install_popen() if block_installation else nullcontext(),
|
||||||
|
):
|
||||||
yield ExportedNodes()
|
yield ExportedNodes()
|
||||||
finally:
|
finally:
|
||||||
# todo: mitigate "/manager/reboot"
|
# todo: mitigate "/manager/reboot"
|
||||||
@ -263,6 +275,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
|
|||||||
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
|
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
|
||||||
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
|
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
|
||||||
# the way community custom nodes is pretty radioactive
|
# the way community custom nodes is pretty radioactive
|
||||||
|
# there's a lot of subtle details here, and unfortunately, once this is called, there are some things that have
|
||||||
|
# to be activated later, in different places, to make all the hacks necessary for custom nodes to work
|
||||||
prepare_vanilla_environment()
|
prepare_vanilla_environment()
|
||||||
|
|
||||||
from ..cmd import folder_paths
|
from ..cmd import folder_paths
|
||||||
|
|||||||
@ -1,9 +1,5 @@
|
|||||||
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
|
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
|
||||||
import collections.abc
|
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import lazy_object_proxy
|
import lazy_object_proxy
|
||||||
|
|
||||||
@ -28,29 +24,3 @@ def get_nodes() -> ExportedNodes:
|
|||||||
if len(current_ctx.custom_nodes) == 0:
|
if len(current_ctx.custom_nodes) == 0:
|
||||||
return nodes
|
return nodes
|
||||||
return exported_nodes_view(nodes, current_ctx.custom_nodes)
|
return exported_nodes_view(nodes, current_ctx.custom_nodes)
|
||||||
|
|
||||||
|
|
||||||
class _NodeClassMappingsShim(collections.abc.Mapping):
|
|
||||||
def __iter__(self):
|
|
||||||
for key in get_nodes().NODE_CLASS_MAPPINGS:
|
|
||||||
yield key
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return get_nodes().NODE_CLASS_MAPPINGS[item]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(get_nodes().NODE_CLASS_MAPPINGS)
|
|
||||||
|
|
||||||
# todo: does this need to be mutable?
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def vanilla_node_execution_environment():
|
|
||||||
# check if we're running with patched nodes
|
|
||||||
if 'nodes' in sys.modules:
|
|
||||||
# this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
|
|
||||||
# easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
|
|
||||||
with patch('nodes.NODE_CLASS_MAPPINGS', _NodeClassMappingsShim()):
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
yield
|
|
||||||
|
|||||||
@ -1,16 +1,103 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import collections.abc
|
||||||
import contextvars
|
import contextvars
|
||||||
import logging
|
import logging
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from importlib.util import find_spec
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import RLock
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
import wrapt
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# there isn't a way to do this per-thread, it's only per process, so the global is valid
|
||||||
|
# we don't want some kind of multiprocessing lock, because this is munging the sys.modules
|
||||||
|
# wrapt.synchronized will be used to synchronize this
|
||||||
_in_environment = False
|
_in_environment = False
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeClassMappingsShim(collections.abc.Mapping):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._active = 0
|
||||||
|
self._active_lock = RLock()
|
||||||
|
|
||||||
|
def activate(self):
|
||||||
|
with self._active_lock:
|
||||||
|
self._active += 1
|
||||||
|
|
||||||
|
def deactivate(self):
|
||||||
|
with self._active_lock:
|
||||||
|
self._active -= 1
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
if self._active > 0:
|
||||||
|
from comfy.nodes_context import get_nodes
|
||||||
|
for key in get_nodes().NODE_CLASS_MAPPINGS:
|
||||||
|
yield key
|
||||||
|
else:
|
||||||
|
from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
|
||||||
|
for key in NODE_CLASS_MAPPINGS:
|
||||||
|
yield key
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
if self._active > 0:
|
||||||
|
from comfy.nodes_context import get_nodes
|
||||||
|
return get_nodes().NODE_CLASS_MAPPINGS[item]
|
||||||
|
else:
|
||||||
|
from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
|
||||||
|
return NODE_CLASS_MAPPINGS[item]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self._active > 0:
|
||||||
|
from comfy.nodes_context import get_nodes
|
||||||
|
return len(get_nodes().NODE_CLASS_MAPPINGS)
|
||||||
|
else:
|
||||||
|
from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
|
||||||
|
return len(NODE_CLASS_MAPPINGS)
|
||||||
|
|
||||||
|
# todo: does this need to be mutable?
|
||||||
|
|
||||||
|
|
||||||
|
class _NodeShim:
|
||||||
|
def __init__(self):
|
||||||
|
self.__name__ = 'nodes'
|
||||||
|
self.__package__ = ''
|
||||||
|
|
||||||
|
nodes_file = None
|
||||||
|
try:
|
||||||
|
# the 'nodes' module is expected to be in the directory above 'comfy'
|
||||||
|
spec = find_spec('comfy')
|
||||||
|
if spec and spec.origin:
|
||||||
|
comfy_package_path = Path(spec.origin).parent
|
||||||
|
nodes_module_dir = comfy_package_path.parent
|
||||||
|
nodes_file = str(nodes_module_dir / 'nodes.py')
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# don't do anything exotic
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.__file__ = nodes_file
|
||||||
|
self.__loader__ = None
|
||||||
|
self.__spec__ = None
|
||||||
|
|
||||||
|
def __node_class_mappings(self) -> _NodeClassMappingsShim:
|
||||||
|
return getattr(self, "NODE_CLASS_MAPPINGS")
|
||||||
|
|
||||||
|
def activate(self):
|
||||||
|
self.__node_class_mappings().activate()
|
||||||
|
|
||||||
|
def deactivate(self):
|
||||||
|
self.__node_class_mappings().deactivate()
|
||||||
|
|
||||||
|
|
||||||
|
@wrapt.synchronized
|
||||||
def prepare_vanilla_environment():
|
def prepare_vanilla_environment():
|
||||||
global _in_environment
|
global _in_environment
|
||||||
if _in_environment:
|
if _in_environment:
|
||||||
@ -38,7 +125,22 @@ def prepare_vanilla_environment():
|
|||||||
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
|
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
|
||||||
module_short_name = module.__name__.split(".")[-1]
|
module_short_name = module.__name__.split(".")[-1]
|
||||||
sys.modules[module_short_name] = module
|
sys.modules[module_short_name] = module
|
||||||
sys.modules['nodes'] = base_nodes
|
|
||||||
|
# easy-use needs a shim
|
||||||
|
# this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
|
||||||
|
# easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
|
||||||
|
# the shim must be activated after importing, which happens in a tightly coupled way
|
||||||
|
# todo: it's not clear if we should skip the dunder methods or not
|
||||||
|
nodes_shim_dir = {k: getattr(base_nodes, k) for k in dir(base_nodes) if not k.startswith("__")}
|
||||||
|
nodes_shim_dir['NODE_CLASS_MAPPINGS'] = _NodeClassMappingsShim()
|
||||||
|
nodes_shim_dir['EXTENSION_WEB_DIRS'] = {}
|
||||||
|
|
||||||
|
nodes_shim = _NodeShim()
|
||||||
|
for k, v in nodes_shim_dir.items():
|
||||||
|
setattr(nodes_shim, k, v)
|
||||||
|
|
||||||
|
sys.modules['nodes'] = nodes_shim
|
||||||
|
|
||||||
comfyui_version = types.ModuleType('comfyui_version', '')
|
comfyui_version = types.ModuleType('comfyui_version', '')
|
||||||
setattr(comfyui_version, "__version__", __version__)
|
setattr(comfyui_version, "__version__", __version__)
|
||||||
sys.modules['comfyui_version'] = comfyui_version
|
sys.modules['comfyui_version'] = comfyui_version
|
||||||
@ -76,3 +178,97 @@ def prepare_vanilla_environment():
|
|||||||
threading.Thread.start = patched_start
|
threading.Thread.start = patched_start
|
||||||
setattr(threading.Thread.start, '__is_patched_by_us', True)
|
setattr(threading.Thread.start, '__is_patched_by_us', True)
|
||||||
logger.debug("Patched `threading.Thread.start` to propagate contextvars.")
|
logger.debug("Patched `threading.Thread.start` to propagate contextvars.")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_pip_install_subprocess_run():
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
original_subprocess_run = subprocess.run
|
||||||
|
|
||||||
|
def custom_side_effect(*args, **kwargs):
|
||||||
|
command_list = args[0] if args else []
|
||||||
|
|
||||||
|
# from easy-use
|
||||||
|
is_pip_install_call = (
|
||||||
|
isinstance(command_list, list) and
|
||||||
|
len(command_list) == 6 and
|
||||||
|
command_list[0] == sys.executable and
|
||||||
|
command_list[1] == '-s' and
|
||||||
|
command_list[2] == '-m' and
|
||||||
|
command_list[3] == 'pip' and
|
||||||
|
command_list[4] == 'install' and
|
||||||
|
isinstance(command_list[5], str)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_pip_install_call:
|
||||||
|
package_name = command_list[5]
|
||||||
|
logger.info(f"Intercepted and mocked `pip install` for: {package_name}")
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.returncode = 0
|
||||||
|
return mock_result
|
||||||
|
else:
|
||||||
|
return original_subprocess_run(*args, **kwargs)
|
||||||
|
|
||||||
|
with patch('subprocess.run') as mock_run:
|
||||||
|
mock_run.side_effect = custom_side_effect
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_pip_install_popen():
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
original_subprocess_popen = subprocess.Popen
|
||||||
|
|
||||||
|
def custom_side_effect(*args, **kwargs):
|
||||||
|
command_list = args[0] if args else []
|
||||||
|
|
||||||
|
is_pip_install_call = (
|
||||||
|
isinstance(command_list, list) and
|
||||||
|
len(command_list) >= 5 and
|
||||||
|
command_list[0] == sys.executable and
|
||||||
|
command_list[1] == "-m" and
|
||||||
|
command_list[2] == "pip" and
|
||||||
|
command_list[3] == "install" and
|
||||||
|
# special case nunchaku
|
||||||
|
"nunchaku" not in command_list[4:]
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_pip_install_call:
|
||||||
|
package_names = command_list[4:]
|
||||||
|
logger.info(f"Intercepted and mocked `subprocess.Popen` for: pip install {' '.join(package_names)}")
|
||||||
|
|
||||||
|
mock_popen_instance = MagicMock()
|
||||||
|
# make stdout and stderr empty iterables so loops over them complete immediately.
|
||||||
|
mock_popen_instance.stdout = []
|
||||||
|
mock_popen_instance.stderr = []
|
||||||
|
|
||||||
|
return mock_popen_instance
|
||||||
|
else:
|
||||||
|
return original_subprocess_popen(*args, **kwargs)
|
||||||
|
|
||||||
|
with patch('subprocess.Popen') as mock_popen:
|
||||||
|
mock_popen.side_effect = custom_side_effect
|
||||||
|
yield mock_popen
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def vanilla_environment_node_execution_hooks():
|
||||||
|
# this handles activating the NODE_CLASS_MAPPINGS shim
|
||||||
|
from comfy.execution_context import current_execution_context
|
||||||
|
ctx = current_execution_context()
|
||||||
|
|
||||||
|
if 'nodes' in sys.modules and isinstance(sys.modules['nodes'], _NodeShim):
|
||||||
|
nodes_shim: _NodeShim = sys.modules['nodes']
|
||||||
|
try:
|
||||||
|
nodes_shim.activate()
|
||||||
|
|
||||||
|
block_installs = ctx and ctx.configuration and ctx.configuration.block_runtime_package_installation is True
|
||||||
|
with (
|
||||||
|
patch_pip_install_subprocess_run() if block_installs else nullcontext(),
|
||||||
|
patch_pip_install_popen() if block_installs else nullcontext(),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
nodes_shim.deactivate()
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|||||||
@ -27,7 +27,7 @@ dependencies = [
|
|||||||
"torchsde>=0.2.6",
|
"torchsde>=0.2.6",
|
||||||
"einops>=0.6.0",
|
"einops>=0.6.0",
|
||||||
"open-clip-torch>=2.24.0",
|
"open-clip-torch>=2.24.0",
|
||||||
"transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2",
|
"transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2,!=4.57.0",
|
||||||
"tokenizers>=0.13.3",
|
"tokenizers>=0.13.3",
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
"peft>=0.10.0",
|
"peft>=0.10.0",
|
||||||
@ -52,7 +52,7 @@ dependencies = [
|
|||||||
"tqdm",
|
"tqdm",
|
||||||
"protobuf>=3.20.0,<5.0.0",
|
"protobuf>=3.20.0,<5.0.0",
|
||||||
"psutil",
|
"psutil",
|
||||||
"ConfigArgParse",
|
"ConfigArgParse>=1.7.1",
|
||||||
"aio-pika",
|
"aio-pika",
|
||||||
"pyjwt[crypto]",
|
"pyjwt[crypto]",
|
||||||
"kornia>=0.7.0",
|
"kornia>=0.7.0",
|
||||||
@ -113,6 +113,8 @@ dependencies = [
|
|||||||
"stringzilla<4.2.0",
|
"stringzilla<4.2.0",
|
||||||
"requests_cache",
|
"requests_cache",
|
||||||
"universal_pathlib",
|
"universal_pathlib",
|
||||||
|
# yanked propcache is omitted
|
||||||
|
"propcache!=0.4.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
|
|||||||
@ -13,6 +13,9 @@ from typing import List, Any, Generator
|
|||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from comfy.cli_args import default_configuration
|
||||||
|
from comfy.execution_context import context_configuration
|
||||||
|
|
||||||
os.environ['OTEL_METRICS_EXPORTER'] = 'none'
|
os.environ['OTEL_METRICS_EXPORTER'] = 'none'
|
||||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
||||||
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
|
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
|
||||||
@ -27,11 +30,8 @@ logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
|
|||||||
|
|
||||||
def run_server(server_arguments: Configuration):
|
def run_server(server_arguments: Configuration):
|
||||||
from comfy.cmd.main import _start_comfyui
|
from comfy.cmd.main import _start_comfyui
|
||||||
from comfy.cli_args import args
|
|
||||||
import asyncio
|
import asyncio
|
||||||
for arg, value in server_arguments.items():
|
asyncio.run(_start_comfyui(configuration=server_arguments))
|
||||||
args[arg] = value
|
|
||||||
asyncio.run(_start_comfyui())
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=False)
|
@pytest.fixture(scope="function", autouse=False)
|
||||||
@ -140,7 +140,7 @@ def comfy_background_server(tmp_path_factory) -> Generator[tuple[Configuration,
|
|||||||
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
|
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
|
||||||
# Start server
|
# Start server
|
||||||
|
|
||||||
configuration = Configuration()
|
configuration = default_configuration()
|
||||||
configuration.listen = "localhost"
|
configuration.listen = "localhost"
|
||||||
configuration.output_directory = str(tmp_path)
|
configuration.output_directory = str(tmp_path)
|
||||||
configuration.input_directory = str(tmp_path)
|
configuration.input_directory = str(tmp_path)
|
||||||
|
|||||||
288
tests/distributed/test_counter.py
Normal file
288
tests/distributed/test_counter.py
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from threading import Thread, Barrier
|
||||||
|
from pathlib import Path
|
||||||
|
import asyncio
|
||||||
|
import contextvars
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import pytest
|
||||||
|
from testcontainers.core.container import DockerContainer
|
||||||
|
from testcontainers.core.wait_strategies import LogMessageWaitStrategy
|
||||||
|
|
||||||
|
from comfy.component_model.file_counter import FileCounter
|
||||||
|
from comfy.component_model.folder_path_types import FolderNames
|
||||||
|
from comfy.execution_context import context_folder_names_and_paths
|
||||||
|
from comfy.cmd.folder_paths import init_default_paths
|
||||||
|
|
||||||
|
|
||||||
|
def is_tool(name):
|
||||||
|
"""Check whether `name` is on PATH and marked as executable."""
|
||||||
|
return shutil.which(name) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(command, check=True):
|
||||||
|
"""Helper to run a shell command."""
|
||||||
|
try:
|
||||||
|
return subprocess.run(command, shell=True, check=check, capture_output=True, text=True)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"Command failed: {command}")
|
||||||
|
print(f"--- STDOUT ---\n{e.stdout}")
|
||||||
|
print(f"--- STDERR ---\n{e.stderr}")
|
||||||
|
print("--------------")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[
|
||||||
|
pytest.param("local", id="local_filesystem"),
|
||||||
|
pytest.param(
|
||||||
|
"nfs",
|
||||||
|
id="nfs_share",
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
not sys.platform.startswith("linux")
|
||||||
|
or not is_tool("mount.nfs")
|
||||||
|
or not is_tool("sudo")
|
||||||
|
or not os.path.exists("/sys/module/nfsd"),
|
||||||
|
reason="NFS tests require sudo, nfs-common, and the 'nfsd' kernel module to be loaded.",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"samba",
|
||||||
|
id="samba_share",
|
||||||
|
marks=pytest.mark.skipif(
|
||||||
|
not sys.platform.startswith("linux")
|
||||||
|
or not is_tool("mount.cifs")
|
||||||
|
or not is_tool("sudo"),
|
||||||
|
reason="Samba tests require sudo on Linux with cifs-utils installed.",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def counter_path_factory(request, tmp_path_factory):
|
||||||
|
"""A parameterized fixture to provide paths on local, NFS, and Samba filesystems."""
|
||||||
|
if request.param == "local":
|
||||||
|
yield lambda name: str(tmp_path_factory.mktemp("local_test") / name)
|
||||||
|
return
|
||||||
|
|
||||||
|
mount_point = tmp_path_factory.mktemp(f"mount_point_{request.param}")
|
||||||
|
|
||||||
|
if request.param == "nfs":
|
||||||
|
# 1. Create the host directory that will be mounted into the container.
|
||||||
|
nfs_source = tmp_path_factory.mktemp("nfs_source")
|
||||||
|
# 2. FIX: Set permissions on the *host* directory.
|
||||||
|
os.chmod(str(nfs_source), 0o777)
|
||||||
|
|
||||||
|
# 3. FIX: Use the new container's required path: /mnt/data
|
||||||
|
container_path = "/mnt/data"
|
||||||
|
|
||||||
|
# 4. FIX: Change to the new container image and configuration
|
||||||
|
nfs_container = DockerContainer("ghcr.io/normal-computing/nfs-server:latest").with_env(
|
||||||
|
"NFS_SERVER_ALLOWED_CLIENTS", "*"
|
||||||
|
).with_kwargs(privileged=True).with_exposed_ports(2049).with_volume_mapping(
|
||||||
|
str(nfs_source), container_path, mode="rw" # Mount to /mnt/data
|
||||||
|
)
|
||||||
|
|
||||||
|
nfs_container.start()
|
||||||
|
|
||||||
|
# 5. FIX: Wait for the new container's export log message
|
||||||
|
# (and remove the timeout as requested)
|
||||||
|
nfs_container.waiting_for(
|
||||||
|
LogMessageWaitStrategy(r"exporting /mnt/data")
|
||||||
|
)
|
||||||
|
|
||||||
|
request.addfinalizer(lambda: nfs_container.stop())
|
||||||
|
|
||||||
|
ip_address = nfs_container.get_container_host_ip()
|
||||||
|
nfs_port = nfs_container.get_exposed_port(2049)
|
||||||
|
try:
|
||||||
|
# 6. FIX: Mount using the new container's command format
|
||||||
|
# (mounts root ":" and uses "-t nfs4")
|
||||||
|
run_command(f"sleep 1 && sudo mount -t nfs4 -o proto=tcp,port={nfs_port} {ip_address}:/ {mount_point}")
|
||||||
|
yield lambda name: str(mount_point / name)
|
||||||
|
finally:
|
||||||
|
run_command(f"sudo umount {mount_point}", check=False)
|
||||||
|
|
||||||
|
elif request.param == "samba":
|
||||||
|
# 1. Create the host directory.
|
||||||
|
samba_source = tmp_path_factory.mktemp("samba_source")
|
||||||
|
# 2. Set permissions on the *host* directory.
|
||||||
|
os.chmod(str(samba_source), 0o777)
|
||||||
|
|
||||||
|
share_name = "storage"
|
||||||
|
|
||||||
|
# 3. FIX: Add the NAME environment variable to tell the container
|
||||||
|
# to create a share named "storage".
|
||||||
|
samba_container = DockerContainer("dockurr/samba:latest").with_env(
|
||||||
|
"RW", "yes"
|
||||||
|
).with_env(
|
||||||
|
"NAME", share_name # <-- This is the crucial fix
|
||||||
|
).with_exposed_ports(445).with_volume_mapping(
|
||||||
|
str(samba_source), "/storage", mode="rw" # This maps the host dir to the internal /storage path
|
||||||
|
)
|
||||||
|
|
||||||
|
samba_container.start()
|
||||||
|
|
||||||
|
# 4. Wait for the correct log message
|
||||||
|
# (and remove the timeout as requested)
|
||||||
|
samba_container.waiting_for(
|
||||||
|
LogMessageWaitStrategy(r"smbd version .* started")
|
||||||
|
)
|
||||||
|
|
||||||
|
request.addfinalizer(lambda: samba_container.stop())
|
||||||
|
|
||||||
|
ip_address = samba_container.get_container_host_ip()
|
||||||
|
samba_port = samba_container.get_exposed_port(445)
|
||||||
|
try:
|
||||||
|
# 5. FIX: Mount with the default username/password, not as guest.
|
||||||
|
run_command(f"sleep 1 && sudo mount -t cifs -o username=samba,password=secret,vers=3.0,port={samba_port},uid=$(id -u),gid=$(id -g) //{ip_address}/{share_name} {mount_point}", check=True)
|
||||||
|
yield lambda name: str(mount_point / name)
|
||||||
|
finally:
|
||||||
|
run_command(f"sudo umount {mount_point}", check=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_initial_state(counter_path_factory):
|
||||||
|
"""Verify initial state and file creation."""
|
||||||
|
counter_file = counter_path_factory("counter.txt")
|
||||||
|
lock_file = counter_path_factory("counter.txt.lock")
|
||||||
|
|
||||||
|
assert not os.path.exists(counter_file)
|
||||||
|
assert not os.path.exists(lock_file)
|
||||||
|
|
||||||
|
counter = FileCounter(str(counter_file))
|
||||||
|
assert counter.get_and_increment() == 0
|
||||||
|
assert os.path.exists(counter_file)
|
||||||
|
with open(counter_file, "r") as f:
|
||||||
|
assert f.read() == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mkdirs(counter_path_factory):
|
||||||
|
counter_file = counter_path_factory("new_dir/counter.txt")
|
||||||
|
assert not os.path.exists(os.path.dirname(counter_file))
|
||||||
|
|
||||||
|
counter = FileCounter(str(counter_file))
|
||||||
|
assert counter.get_and_increment() == 0
|
||||||
|
assert counter.get_and_increment() == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_increment_and_decrement(counter_path_factory):
|
||||||
|
"""Test the increment and decrement logic."""
|
||||||
|
counter_file = counter_path_factory("counter.txt")
|
||||||
|
counter = FileCounter(str(counter_file))
|
||||||
|
|
||||||
|
assert counter.get_and_increment() == 0 # val: 0, new_val: 1
|
||||||
|
assert counter.get_and_increment() == 1 # val: 1, new_val: 2
|
||||||
|
assert counter.get_and_increment() == 2 # val: 2, new_val: 3
|
||||||
|
|
||||||
|
assert counter.decrement_and_get() == 2 # val: 3, new_val: 2
|
||||||
|
assert counter.decrement_and_get() == 1 # val: 2, new_val: 1
|
||||||
|
|
||||||
|
assert counter.get_and_increment() == 1 # val: 1, new_val: 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_instances_same_path(counter_path_factory):
|
||||||
|
"""Verify that multiple FileCounter instances on the same path work correctly."""
|
||||||
|
counter_file = counter_path_factory("counter.txt")
|
||||||
|
counter1 = FileCounter(str(counter_file))
|
||||||
|
counter2 = FileCounter(str(counter_file))
|
||||||
|
|
||||||
|
assert counter1.get_and_increment() == 0
|
||||||
|
assert counter2.get_and_increment() == 1
|
||||||
|
assert counter1.decrement_and_get() == 1
|
||||||
|
assert counter2.get_and_increment() == 1
|
||||||
|
|
||||||
|
with open(counter_file, "r") as f:
|
||||||
|
assert f.read() == "2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multithreaded_access(counter_path_factory):
|
||||||
|
"""Ensure atomicity with multiple threads."""
|
||||||
|
counter_file = counter_path_factory("counter.txt")
|
||||||
|
counter = FileCounter(str(counter_file))
|
||||||
|
num_threads = 10
|
||||||
|
increments_per_thread = 100
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
for _ in range(increments_per_thread):
|
||||||
|
counter.get_and_increment()
|
||||||
|
|
||||||
|
threads = [Thread(target=worker) for _ in range(num_threads)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
with open(counter_file, "r") as f:
|
||||||
|
final_value = int(f.read())
|
||||||
|
assert final_value == num_threads * increments_per_thread
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_manager(counter_path_factory):
|
||||||
|
"""Test that the counter can be used as a context manager."""
|
||||||
|
counter_file = counter_path_factory("counter.txt")
|
||||||
|
counter = FileCounter(str(counter_file))
|
||||||
|
|
||||||
|
# Initial state should be 0
|
||||||
|
assert counter.get_and_increment() == 0
|
||||||
|
with open(counter_file) as f:
|
||||||
|
assert f.read() == "1"
|
||||||
|
|
||||||
|
with counter as wrapper:
|
||||||
|
# The wrapper's value is the original count before increment.
|
||||||
|
assert wrapper.value == 1
|
||||||
|
# It can also be used as an integer directly.
|
||||||
|
assert int(wrapper) == 1
|
||||||
|
with open(counter_file) as f:
|
||||||
|
assert f.read() == "2" # Inside context, value is 2
|
||||||
|
|
||||||
|
with open(counter_file) as f:
|
||||||
|
assert f.read() == "1" # Exited context, decremented back to 1
|
||||||
|
# After exit, the wrapper's 'ctr' attribute holds the new value.
|
||||||
|
assert wrapper.ctr == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cleanup_temp_multithreaded(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that cleanup_temp correctly deletes the temp directory only
|
||||||
|
after the last thread has exited the context.
|
||||||
|
"""
|
||||||
|
# 1. Use the application's context to define the temp directory for this test.
|
||||||
|
# This is a cleaner approach than mocking.
|
||||||
|
base_dir = tmp_path / "base"
|
||||||
|
temp_dir_override = base_dir / "temp"
|
||||||
|
fn = FolderNames(base_paths=[base_dir])
|
||||||
|
init_default_paths(fn, base_paths_from_configuration=False)
|
||||||
|
# Override the default temp path
|
||||||
|
fn.temp_directory = temp_dir_override
|
||||||
|
|
||||||
|
from comfy.component_model.file_counter import cleanup_temp
|
||||||
|
|
||||||
|
num_threads = 5
|
||||||
|
# Barrier to make threads wait for each other before exiting.
|
||||||
|
barrier = Barrier(num_threads)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
"""The task for each thread. It enters the cleanup context and waits."""
|
||||||
|
with cleanup_temp():
|
||||||
|
# The temp directory and counter file should exist inside the context.
|
||||||
|
assert temp_dir_override.exists()
|
||||||
|
assert (temp_dir_override / "counter.txt").exists()
|
||||||
|
# After exiting, the directory should still exist until the last thread is done.
|
||||||
|
|
||||||
|
# Use the context manager to set the folder paths for the current async context.
|
||||||
|
with context_folder_names_and_paths(fn):
|
||||||
|
# Capture the current context, which now includes the folder_paths settings.
|
||||||
|
|
||||||
|
# Run the worker function in a thread pool, applying the captured context to each task.
|
||||||
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||||
|
|
||||||
|
tasks = [loop.run_in_executor(executor, contextvars.copy_context().run, worker) for _ in range(num_threads)]
|
||||||
|
# Wait for all threads to complete their work.
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# After all threads have joined, the counter should be 0, and the directory deleted.
|
||||||
|
assert not temp_dir_override.exists()
|
||||||
|
# the base dir is not going to be deleted
|
||||||
|
assert base_dir.exists()
|
||||||
@ -16,6 +16,7 @@ from typing import List, Dict, Any, Generator
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy.cli_args import default_configuration
|
||||||
from comfy.cli_args_types import Configuration
|
from comfy.cli_args_types import Configuration
|
||||||
from comfy_execution.graph_utils import GraphBuilder
|
from comfy_execution.graph_utils import GraphBuilder
|
||||||
from .test_execution import ComfyClient, RunResult
|
from .test_execution import ComfyClient, RunResult
|
||||||
@ -188,7 +189,7 @@ class TestProgressIsolation:
|
|||||||
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
|
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
|
||||||
# Start server
|
# Start server
|
||||||
|
|
||||||
configuration = Configuration()
|
configuration = default_configuration()
|
||||||
configuration.listen = args_pytest["listen"]
|
configuration.listen = args_pytest["listen"]
|
||||||
configuration.port = args_pytest["port"]
|
configuration.port = args_pytest["port"]
|
||||||
configuration.cpu = True
|
configuration.cpu = True
|
||||||
@ -205,7 +206,7 @@ testing_pack:
|
|||||||
yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
|
yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
|
||||||
with open(yaml_path, mode="wt") as f:
|
with open(yaml_path, mode="wt") as f:
|
||||||
f.write(extra_nodes)
|
f.write(extra_nodes)
|
||||||
configuration.extra_model_paths_config = [str(yaml_path)]
|
configuration.extra_model_paths_config = [yaml_path]
|
||||||
|
|
||||||
yield from comfy_background_server_from_config(configuration)
|
yield from comfy_background_server_from_config(configuration)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import importlib.resources
|
import importlib.resources
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from importlib.abc import Traversable
|
from importlib.abc import Traversable
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -11,6 +12,8 @@ from comfy.model_downloader_types import CivitFile, HuggingFile
|
|||||||
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
|
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
|
||||||
from . import workflows
|
from . import workflows
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=False)
|
@pytest.fixture(scope="function", autouse=False)
|
||||||
async def client(tmp_path_factory) -> Comfy:
|
async def client(tmp_path_factory) -> Comfy:
|
||||||
@ -35,6 +38,7 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
|
|||||||
|
|
||||||
prompt = Prompt.validate(workflow)
|
prompt = Prompt.validate(workflow)
|
||||||
# todo: add all the models we want to test a bit m2ore elegantly
|
# todo: add all the models we want to test a bit m2ore elegantly
|
||||||
|
outputs = {}
|
||||||
try:
|
try:
|
||||||
outputs = await client.queue_prompt(prompt)
|
outputs = await client.queue_prompt(prompt)
|
||||||
except TorchAudioNotFoundError:
|
except TorchAudioNotFoundError:
|
||||||
@ -46,8 +50,14 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
|
|||||||
elif any(v.class_type == "SaveAudio" for v in prompt.values()):
|
elif any(v.class_type == "SaveAudio" for v in prompt.values()):
|
||||||
save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
|
save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
|
||||||
assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
|
assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
|
||||||
|
elif any(v.class_type == "SaveAnimatedWEBP" for v in prompt.values()):
|
||||||
|
save_video_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAnimatedWEBP")
|
||||||
|
assert outputs[save_video_node_id]["images"][0]["filename"] is not None
|
||||||
elif any(v.class_type == "PreviewString" for v in prompt.values()):
|
elif any(v.class_type == "PreviewString" for v in prompt.values()):
|
||||||
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
|
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
|
||||||
output_str = outputs[save_image_node_id]["string"][0]
|
output_str = outputs[save_image_node_id]["string"][0]
|
||||||
assert output_str is not None
|
assert output_str is not None
|
||||||
assert len(output_str) > 0
|
assert len(output_str) > 0
|
||||||
|
else:
|
||||||
|
assert len(outputs) > 0
|
||||||
|
logger.warning(f"test {workflow_name} did not have a node that could be checked for output")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user