Improvements to compatibility with custom nodes, distributed

backends and other changes

 - remove uv.lock since it will not be used in most cases for installation
 - add cli args to prevent some custom nodes from installing packages at runtime
 - temp directories can now be shared between workers without being deleted
 - propcache yanked is now in the dependencies
 - fix configuration arguments loading in some tests
This commit is contained in:
doctorpangloss 2025-11-04 17:40:19 -08:00
parent d9e3ba4bec
commit 98ae55b059
16 changed files with 678 additions and 7173 deletions

View File

@ -31,10 +31,12 @@ class LogInterceptor(io.TextIOWrapper):
if isinstance(data, str) and data.startswith("\r") and len(logs) > 0 and not logs[-1]["m"].endswith("\n"):
logs.pop()
logs.append(entry)
super().write(data)
if not self.closed:
super().write(data)
def flush(self):
super().flush()
if not self.closed:
super().flush()
for cb in self._flush_callbacks:
cb(self._logs_since_flush)
self._logs_since_flush = []
@ -56,7 +58,8 @@ def on_flush(callback):
class StackTraceLogger(logging.Logger):
def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
if level >= logging.ERROR:
if not stack_info and level >= logging.ERROR and exc_info is None:
# create a stack even when there is no exception
stack_info = True
super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)

View File

@ -273,6 +273,12 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
parser.add_argument(
"--block-runtime-package-installation",
action="store_true",
help="When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental)."
)
default_db_url = db_config()
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--workflows", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")

View File

@ -168,6 +168,7 @@ class Configuration(dict):
blacklist_custom_nodes (list[str]): Specify custom node folders to never load. Accepts shell-style globs.
whitelist_custom_nodes (list[str]): Specify custom node folders to load even when --disable-all-custom-nodes is enabled.
default_device (Optional[int]): Set the id of the default device, all other devices will stay visible.
block_runtime_package_installation (Optional[bool]): When set, custom nodes like ComfyUI Manager, Easy Use, Nunchaku and others will not be able to use pip or uv to install packages at runtime (experimental).
"""
def __init__(self, **kwargs):
@ -286,6 +287,7 @@ class Configuration(dict):
self.comfy_api_base: str = "https://api.comfy.org"
self.database_url: str = db_config()
self.default_device: Optional[int] = None
self.block_runtime_package_installation = None
for key, value in kwargs.items():
self[key] = value
@ -315,6 +317,8 @@ class Configuration(dict):
super().update(__m, **kwargs)
for k, v in changes.items():
self._notify_observers(k, v)
# make this more pythonic
return self
def register_observer(self, observer: ConfigObserver):
self._observers.append(observer)

View File

@ -19,19 +19,23 @@ from typing import List, Optional, Tuple, Literal
# order matters
from .main_pre import tracer
import torch
from frozendict import frozendict
from comfy_execution.graph_types import FrozenTopologicalSort, Input
from opentelemetry.trace import get_current_span, StatusCode, Status
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, \
make_locked_method_func
from comfy_api.latest import io
from comfy_compatibility.vanilla import vanilla_environment_node_execution_hooks
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
DependencyAwareCache, \
BasicCache
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_types import FrozenTopologicalSort
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, \
WebUIProgressHandler, \
ProgressRegistry
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
from comfy_execution.validation import validate_node_input
from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
@ -44,13 +48,11 @@ from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
ExecutionStatusAsDict
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
from ..execution_ext import should_panic_on_exception
from ..node_requests_caching import use_requests_caching
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
from ..nodes_context import get_nodes, vanilla_node_execution_environment
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler, \
ProgressRegistry
from comfy_execution.validation import validate_node_input
from ..nodes_context import get_nodes
_module_properties = create_module_properties()
logger = logging.getLogger(__name__)
@ -474,9 +476,11 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
:param pending_subgraph_results:
:return:
"""
with (context_execute_node(node_id),
vanilla_node_execution_environment(),
use_requests_caching()):
with (
context_execute_node(node_id),
vanilla_environment_node_execution_hooks(),
use_requests_caching(),
):
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)

View File

@ -1,7 +1,7 @@
import asyncio
import contextvars
import gc
import itertools
import logging
import os
import shutil
@ -11,8 +11,10 @@ from pathlib import Path
from typing import Optional
# main_pre must be the earliest import
from .main_pre import args
from .main_pre import tracer
from ..cli_args_types import Configuration
from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
from ..execution_context import current_execution_context
from . import hook_breaker_ac10a0
from .extra_model_paths import load_extra_path_config
from .. import model_management
@ -51,6 +53,7 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
from ..cmd import execution
from ..component_model import queue_types
from .. import model_management
args = current_execution_context().configuration
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
@ -147,10 +150,14 @@ def setup_database():
init_db()
async def _start_comfyui(from_script_dir: Optional[Path] = None):
async def _start_comfyui(from_script_dir: Optional[Path] = None, configuration: Optional[Configuration] = None):
from ..execution_context import context_configuration
from ..cli_args import cli_args_configuration
with context_configuration(cli_args_configuration()):
configuration = configuration or cli_args_configuration()
with (
context_configuration(configuration),
fc_cleanup_temp()
):
await __start_comfyui(from_script_dir=from_script_dir)
@ -159,6 +166,7 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
Runs ComfyUI's frontend and backend like upstream.
:param from_script_dir: when set to a path, assumes that you are running ComfyUI's legacy main.py entrypoint at the root of the git repository located at the path
"""
args = current_execution_context().configuration
if not from_script_dir:
os_getcwd = os.getcwd()
else:
@ -168,7 +176,6 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logger.debug(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir)
cleanup_temp()
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
@ -305,7 +312,6 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
finally:
if distributed:
await q.close()
cleanup_temp()
def entrypoint():

View File

@ -0,0 +1,90 @@
import shutil
from contextlib import contextmanager
from pathlib import Path
import filelock
class ContextWrapper:
"""A wrapper to hold context manager values for entry and exit."""
def __init__(self, value):
self.value = value
self.ctr = None
def __int__(self):
return self.value
class FileCounter:
def __init__(self, path):
self.path = Path(path)
async def __aenter__(self):
wrapper = ContextWrapper(self.get_and_increment())
self._context_wrapper = wrapper
return wrapper
async def __aexit__(self, exc_type, exc_val, exc_tb):
self._context_wrapper.ctr = self.decrement_and_get()
def __enter__(self):
"""Increment on entering the context and return a wrapper."""
wrapper = ContextWrapper(self.get_and_increment())
self._context_wrapper = wrapper
return wrapper
def __exit__(self, exc_type, exc_val, exc_tb):
"""Decrement on exiting the context and update the wrapper."""
self._context_wrapper.ctr = self.decrement_and_get()
def _read_and_write(self, operation):
lock = filelock.FileLock(f"{self.path}.lock")
with lock:
count = 0
try:
with open(self.path, 'r') as f:
content = f.read().strip()
if content:
count = int(content)
except FileNotFoundError:
# File doesn't exist, will be created with initial value.
pass
except ValueError:
# File is corrupt or empty, treat as 0 and overwrite.
pass
original_count = count
new_count = operation(count)
self.path.parent.mkdir(parents=True, exist_ok=True)
with open(self.path, 'w') as f:
f.write(str(new_count))
return original_count, new_count
def get_and_increment(self):
"""Atomically reads the current value, increments it, and returns the original value."""
original_count, _ = self._read_and_write(lambda x: x + 1)
return original_count
def decrement_and_get(self):
"""Atomically decrements the value and returns the new value."""
_, new_count = self._read_and_write(lambda x: x - 1)
return new_count
@contextmanager
def cleanup_temp():
from ..cli_args import args
from ..cmd import folder_paths
tmp_dir = Path(args.temp_directory or folder_paths.get_temp_directory())
counter_path = tmp_dir / "counter.txt"
fc_i = -1
try:
with FileCounter(counter_path) as fc:
yield
fc_i = fc.ctr
finally:
if fc_i == 0 and tmp_dir.is_dir():
shutil.rmtree(tmp_dir, ignore_errors=True)

View File

@ -1,6 +1,7 @@
import asyncio
from ..cmd.main_pre import args
from ..component_model.file_counter import cleanup_temp
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
@ -10,19 +11,27 @@ async def main():
args.distributed_queue_worker = True
args.distributed_queue_frontend = False
# in workers, there is a different default
if args.block_runtime_package_installation is None:
args.block_runtime_package_installation = True
assert args.distributed_queue_connection_uri is not None, "Set the --distributed-queue-connection-uri argument to your RabbitMQ server"
configure_application_paths(args)
executor = await executor_from_args(args)
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name,
executor=executor):
stop = asyncio.Event()
try:
await stop.wait()
except asyncio.CancelledError:
pass
async with (
DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name,
executor=executor),
):
with cleanup_temp():
stop = asyncio.Event()
try:
await stop.wait()
except (asyncio.CancelledError, InterruptedError, KeyboardInterrupt):
pass
def entrypoint():

View File

@ -8,18 +8,19 @@ import os
import sys
import time
import types
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from os.path import join, basename, dirname, isdir, isfile, exists, abspath, split, splitext, realpath
from typing import Iterable, Any, Generator
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from comfy_compatibility.vanilla import prepare_vanilla_environment
from comfy_compatibility.vanilla import prepare_vanilla_environment, patch_pip_install_subprocess_run, patch_pip_install_popen
from . import base_nodes
from .comfyui_v3_package_imports import _comfy_entrypoint_upstream_v3_imports
from .package_typing import ExportedNodes
from ..cmd import folder_paths
from ..component_model.plugins import prompt_server_instance_routes
from ..distributed.server_stub import ServerStub
from ..execution_context import current_execution_context
logger = logging.getLogger(__name__)
@ -44,6 +45,10 @@ class StreamToLogger:
# The logger handles its own flushing, so this can be a no-op.
pass
@property
def encoding(self):
return "utf-8"
class _PromptServerStub(ServerStub):
def __init__(self):
@ -140,6 +145,7 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
"comfyui-manager",
"comfyui_ryanonyheinside",
"comfyui-easy-use",
"comfyui_custom_nodes_alekpet",
):
from ..cmd import folder_paths
old_file = folder_paths.__file__
@ -147,9 +153,15 @@ def _exec_mitigations(module: types.ModuleType, module_path: str) -> Generator[E
try:
# mitigate path
new_path = join(abspath(join(dirname(old_file), "..", "..")), basename(old_file))
config = current_execution_context()
with patch.object(folder_paths, "__file__", new_path), \
patch.object(sys.modules['nodes'], "EXTENSION_WEB_DIRS", {}, create=True): # mitigate JS copy
block_installation = config and config.configuration and config.configuration.block_runtime_package_installation
with (
patch.object(folder_paths, "__file__", new_path),
# mitigate packages installing things dynamically
patch_pip_install_subprocess_run() if block_installation else nullcontext(),
patch_pip_install_popen() if block_installation else nullcontext(),
):
yield ExportedNodes()
finally:
# todo: mitigate "/manager/reboot"
@ -263,6 +275,8 @@ def mitigated_import_of_vanilla_custom_nodes() -> ExportedNodes:
# this mitigation puts files that custom nodes expects are at the root of the repository back where they should be
# found. we're in the middle of executing the import of execution and server, in all likelihood, so like all things,
# the way community custom nodes is pretty radioactive
# there's a lot of subtle details here, and unfortunately, once this is called, there are some things that have
# to be activated later, in different places, to make all the hacks necessary for custom nodes to work
prepare_vanilla_environment()
from ..cmd import folder_paths

View File

@ -1,9 +1,5 @@
# todo: this should be defined in a common place, the fact that the nodes are imported by execution the way that they are is pretty radioactive
import collections.abc
import sys
import threading
from contextlib import contextmanager
from unittest.mock import patch
import lazy_object_proxy
@ -28,29 +24,3 @@ def get_nodes() -> ExportedNodes:
if len(current_ctx.custom_nodes) == 0:
return nodes
return exported_nodes_view(nodes, current_ctx.custom_nodes)
class _NodeClassMappingsShim(collections.abc.Mapping):
def __iter__(self):
for key in get_nodes().NODE_CLASS_MAPPINGS:
yield key
def __getitem__(self, item):
return get_nodes().NODE_CLASS_MAPPINGS[item]
def __len__(self):
return len(get_nodes().NODE_CLASS_MAPPINGS)
# todo: does this need to be mutable?
@contextmanager
def vanilla_node_execution_environment():
# check if we're running with patched nodes
if 'nodes' in sys.modules:
# this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
# easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
with patch('nodes.NODE_CLASS_MAPPINGS', _NodeClassMappingsShim()):
yield
else:
yield

View File

@ -1,16 +1,103 @@
from __future__ import annotations
import collections.abc
import contextvars
import logging
import subprocess
import sys
import types
from contextlib import contextmanager, nullcontext
from functools import partial
from importlib.util import find_spec
from pathlib import Path
from threading import RLock
from typing import Dict
import wrapt
logger = logging.getLogger(__name__)
# there isn't a way to do this per-thread, it's only per process, so the global is valid
# we don't want some kind of multiprocessing lock, because this is munging the sys.modules
# wrapt.synchronized will be used to synchronize this
_in_environment = False
class _NodeClassMappingsShim(collections.abc.Mapping):
def __init__(self):
super().__init__()
self._active = 0
self._active_lock = RLock()
def activate(self):
with self._active_lock:
self._active += 1
def deactivate(self):
with self._active_lock:
self._active -= 1
def __iter__(self):
if self._active > 0:
from comfy.nodes_context import get_nodes
for key in get_nodes().NODE_CLASS_MAPPINGS:
yield key
else:
from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
for key in NODE_CLASS_MAPPINGS:
yield key
def __getitem__(self, item):
if self._active > 0:
from comfy.nodes_context import get_nodes
return get_nodes().NODE_CLASS_MAPPINGS[item]
else:
from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
return NODE_CLASS_MAPPINGS[item]
def __len__(self):
if self._active > 0:
from comfy.nodes_context import get_nodes
return len(get_nodes().NODE_CLASS_MAPPINGS)
else:
from comfy.nodes.base_nodes import NODE_CLASS_MAPPINGS
return len(NODE_CLASS_MAPPINGS)
# todo: does this need to be mutable?
class _NodeShim:
def __init__(self):
self.__name__ = 'nodes'
self.__package__ = ''
nodes_file = None
try:
# the 'nodes' module is expected to be in the directory above 'comfy'
spec = find_spec('comfy')
if spec and spec.origin:
comfy_package_path = Path(spec.origin).parent
nodes_module_dir = comfy_package_path.parent
nodes_file = str(nodes_module_dir / 'nodes.py')
except (ImportError, AttributeError):
# don't do anything exotic
pass
self.__file__ = nodes_file
self.__loader__ = None
self.__spec__ = None
def __node_class_mappings(self) -> _NodeClassMappingsShim:
return getattr(self, "NODE_CLASS_MAPPINGS")
def activate(self):
self.__node_class_mappings().activate()
def deactivate(self):
self.__node_class_mappings().deactivate()
@wrapt.synchronized
def prepare_vanilla_environment():
global _in_environment
if _in_environment:
@ -38,7 +125,22 @@ def prepare_vanilla_environment():
for module in (cuda_malloc, folder_paths, latent_preview, node_helpers, protocol):
module_short_name = module.__name__.split(".")[-1]
sys.modules[module_short_name] = module
sys.modules['nodes'] = base_nodes
# easy-use needs a shim
# this ensures NODE_CLASS_MAPPINGS is loaded lazily and contains all the nodes loaded so far, not just the base nodes
# easy-use and other nodes expect NODE_CLASS_MAPPINGS to contain all the nodes in the environment
# the shim must be activated after importing, which happens in a tightly coupled way
# todo: it's not clear if we should skip the dunder methods or not
nodes_shim_dir = {k: getattr(base_nodes, k) for k in dir(base_nodes) if not k.startswith("__")}
nodes_shim_dir['NODE_CLASS_MAPPINGS'] = _NodeClassMappingsShim()
nodes_shim_dir['EXTENSION_WEB_DIRS'] = {}
nodes_shim = _NodeShim()
for k, v in nodes_shim_dir.items():
setattr(nodes_shim, k, v)
sys.modules['nodes'] = nodes_shim
comfyui_version = types.ModuleType('comfyui_version', '')
setattr(comfyui_version, "__version__", __version__)
sys.modules['comfyui_version'] = comfyui_version
@ -76,3 +178,97 @@ def prepare_vanilla_environment():
threading.Thread.start = patched_start
setattr(threading.Thread.start, '__is_patched_by_us', True)
logger.debug("Patched `threading.Thread.start` to propagate contextvars.")
@contextmanager
def patch_pip_install_subprocess_run():
from unittest.mock import patch, MagicMock
original_subprocess_run = subprocess.run
def custom_side_effect(*args, **kwargs):
command_list = args[0] if args else []
# from easy-use
is_pip_install_call = (
isinstance(command_list, list) and
len(command_list) == 6 and
command_list[0] == sys.executable and
command_list[1] == '-s' and
command_list[2] == '-m' and
command_list[3] == 'pip' and
command_list[4] == 'install' and
isinstance(command_list[5], str)
)
if is_pip_install_call:
package_name = command_list[5]
logger.info(f"Intercepted and mocked `pip install` for: {package_name}")
mock_result = MagicMock()
mock_result.returncode = 0
return mock_result
else:
return original_subprocess_run(*args, **kwargs)
with patch('subprocess.run') as mock_run:
mock_run.side_effect = custom_side_effect
yield
@contextmanager
def patch_pip_install_popen():
from unittest.mock import patch, MagicMock
original_subprocess_popen = subprocess.Popen
def custom_side_effect(*args, **kwargs):
command_list = args[0] if args else []
is_pip_install_call = (
isinstance(command_list, list) and
len(command_list) >= 5 and
command_list[0] == sys.executable and
command_list[1] == "-m" and
command_list[2] == "pip" and
command_list[3] == "install" and
# special case nunchaku
"nunchaku" not in command_list[4:]
)
if is_pip_install_call:
package_names = command_list[4:]
logger.info(f"Intercepted and mocked `subprocess.Popen` for: pip install {' '.join(package_names)}")
mock_popen_instance = MagicMock()
# make stdout and stderr empty iterables so loops over them complete immediately.
mock_popen_instance.stdout = []
mock_popen_instance.stderr = []
return mock_popen_instance
else:
return original_subprocess_popen(*args, **kwargs)
with patch('subprocess.Popen') as mock_popen:
mock_popen.side_effect = custom_side_effect
yield mock_popen
@contextmanager
def vanilla_environment_node_execution_hooks():
# this handles activating the NODE_CLASS_MAPPINGS shim
from comfy.execution_context import current_execution_context
ctx = current_execution_context()
if 'nodes' in sys.modules and isinstance(sys.modules['nodes'], _NodeShim):
nodes_shim: _NodeShim = sys.modules['nodes']
try:
nodes_shim.activate()
block_installs = ctx and ctx.configuration and ctx.configuration.block_runtime_package_installation is True
with (
patch_pip_install_subprocess_run() if block_installs else nullcontext(),
patch_pip_install_popen() if block_installs else nullcontext(),
):
yield
finally:
nodes_shim.deactivate()
else:
yield

View File

@ -27,7 +27,7 @@ dependencies = [
"torchsde>=0.2.6",
"einops>=0.6.0",
"open-clip-torch>=2.24.0",
"transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2",
"transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2,!=4.57.0",
"tokenizers>=0.13.3",
"sentencepiece",
"peft>=0.10.0",
@ -52,7 +52,7 @@ dependencies = [
"tqdm",
"protobuf>=3.20.0,<5.0.0",
"psutil",
"ConfigArgParse",
"ConfigArgParse>=1.7.1",
"aio-pika",
"pyjwt[crypto]",
"kornia>=0.7.0",
@ -113,6 +113,8 @@ dependencies = [
"stringzilla<4.2.0",
"requests_cache",
"universal_pathlib",
# yanked propcache is omitted
"propcache!=0.4.0",
]
[build-system]

View File

@ -13,6 +13,9 @@ from typing import List, Any, Generator
import pytest
import requests
from comfy.cli_args import default_configuration
from comfy.execution_context import context_configuration
os.environ['OTEL_METRICS_EXPORTER'] = 'none'
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
@ -27,11 +30,8 @@ logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
def run_server(server_arguments: Configuration):
from comfy.cmd.main import _start_comfyui
from comfy.cli_args import args
import asyncio
for arg, value in server_arguments.items():
args[arg] = value
asyncio.run(_start_comfyui())
asyncio.run(_start_comfyui(configuration=server_arguments))
@pytest.fixture(scope="function", autouse=False)
@ -140,7 +140,7 @@ def comfy_background_server(tmp_path_factory) -> Generator[tuple[Configuration,
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
# Start server
configuration = Configuration()
configuration = default_configuration()
configuration.listen = "localhost"
configuration.output_directory = str(tmp_path)
configuration.input_directory = str(tmp_path)

View File

@ -0,0 +1,288 @@
import os
import shutil
import subprocess
import sys
from threading import Thread, Barrier
from pathlib import Path
import asyncio
import contextvars
from concurrent.futures import ThreadPoolExecutor
import pytest
from testcontainers.core.container import DockerContainer
from testcontainers.core.wait_strategies import LogMessageWaitStrategy
from comfy.component_model.file_counter import FileCounter
from comfy.component_model.folder_path_types import FolderNames
from comfy.execution_context import context_folder_names_and_paths
from comfy.cmd.folder_paths import init_default_paths
def is_tool(name):
"""Check whether `name` is on PATH and marked as executable."""
return shutil.which(name) is not None
def run_command(command, check=True):
"""Helper to run a shell command."""
try:
return subprocess.run(command, shell=True, check=check, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
print(f"Command failed: {command}")
print(f"--- STDOUT ---\n{e.stdout}")
print(f"--- STDERR ---\n{e.stderr}")
print("--------------")
raise
@pytest.fixture(
params=[
pytest.param("local", id="local_filesystem"),
pytest.param(
"nfs",
id="nfs_share",
marks=pytest.mark.skipif(
not sys.platform.startswith("linux")
or not is_tool("mount.nfs")
or not is_tool("sudo")
or not os.path.exists("/sys/module/nfsd"),
reason="NFS tests require sudo, nfs-common, and the 'nfsd' kernel module to be loaded.",
),
),
pytest.param(
"samba",
id="samba_share",
marks=pytest.mark.skipif(
not sys.platform.startswith("linux")
or not is_tool("mount.cifs")
or not is_tool("sudo"),
reason="Samba tests require sudo on Linux with cifs-utils installed.",
),
),
]
)
def counter_path_factory(request, tmp_path_factory):
"""A parameterized fixture to provide paths on local, NFS, and Samba filesystems."""
if request.param == "local":
yield lambda name: str(tmp_path_factory.mktemp("local_test") / name)
return
mount_point = tmp_path_factory.mktemp(f"mount_point_{request.param}")
if request.param == "nfs":
# 1. Create the host directory that will be mounted into the container.
nfs_source = tmp_path_factory.mktemp("nfs_source")
# 2. FIX: Set permissions on the *host* directory.
os.chmod(str(nfs_source), 0o777)
# 3. FIX: Use the new container's required path: /mnt/data
container_path = "/mnt/data"
# 4. FIX: Change to the new container image and configuration
nfs_container = DockerContainer("ghcr.io/normal-computing/nfs-server:latest").with_env(
"NFS_SERVER_ALLOWED_CLIENTS", "*"
).with_kwargs(privileged=True).with_exposed_ports(2049).with_volume_mapping(
str(nfs_source), container_path, mode="rw" # Mount to /mnt/data
)
nfs_container.start()
# 5. FIX: Wait for the new container's export log message
# (and remove the timeout as requested)
nfs_container.waiting_for(
LogMessageWaitStrategy(r"exporting /mnt/data")
)
request.addfinalizer(lambda: nfs_container.stop())
ip_address = nfs_container.get_container_host_ip()
nfs_port = nfs_container.get_exposed_port(2049)
try:
# 6. FIX: Mount using the new container's command format
# (mounts root ":" and uses "-t nfs4")
run_command(f"sleep 1 && sudo mount -t nfs4 -o proto=tcp,port={nfs_port} {ip_address}:/ {mount_point}")
yield lambda name: str(mount_point / name)
finally:
run_command(f"sudo umount {mount_point}", check=False)
elif request.param == "samba":
# 1. Create the host directory.
samba_source = tmp_path_factory.mktemp("samba_source")
# 2. Set permissions on the *host* directory.
os.chmod(str(samba_source), 0o777)
share_name = "storage"
# 3. FIX: Add the NAME environment variable to tell the container
# to create a share named "storage".
samba_container = DockerContainer("dockurr/samba:latest").with_env(
"RW", "yes"
).with_env(
"NAME", share_name # <-- This is the crucial fix
).with_exposed_ports(445).with_volume_mapping(
str(samba_source), "/storage", mode="rw" # This maps the host dir to the internal /storage path
)
samba_container.start()
# 4. Wait for the correct log message
# (and remove the timeout as requested)
samba_container.waiting_for(
LogMessageWaitStrategy(r"smbd version .* started")
)
request.addfinalizer(lambda: samba_container.stop())
ip_address = samba_container.get_container_host_ip()
samba_port = samba_container.get_exposed_port(445)
try:
# 5. FIX: Mount with the default username/password, not as guest.
run_command(f"sleep 1 && sudo mount -t cifs -o username=samba,password=secret,vers=3.0,port={samba_port},uid=$(id -u),gid=$(id -g) //{ip_address}/{share_name} {mount_point}", check=True)
yield lambda name: str(mount_point / name)
finally:
run_command(f"sudo umount {mount_point}", check=False)
def test_initial_state(counter_path_factory):
"""Verify initial state and file creation."""
counter_file = counter_path_factory("counter.txt")
lock_file = counter_path_factory("counter.txt.lock")
assert not os.path.exists(counter_file)
assert not os.path.exists(lock_file)
counter = FileCounter(str(counter_file))
assert counter.get_and_increment() == 0
assert os.path.exists(counter_file)
with open(counter_file, "r") as f:
assert f.read() == "1"
def test_mkdirs(counter_path_factory):
counter_file = counter_path_factory("new_dir/counter.txt")
assert not os.path.exists(os.path.dirname(counter_file))
counter = FileCounter(str(counter_file))
assert counter.get_and_increment() == 0
assert counter.get_and_increment() == 1
def test_increment_and_decrement(counter_path_factory):
"""Test the increment and decrement logic."""
counter_file = counter_path_factory("counter.txt")
counter = FileCounter(str(counter_file))
assert counter.get_and_increment() == 0 # val: 0, new_val: 1
assert counter.get_and_increment() == 1 # val: 1, new_val: 2
assert counter.get_and_increment() == 2 # val: 2, new_val: 3
assert counter.decrement_and_get() == 2 # val: 3, new_val: 2
assert counter.decrement_and_get() == 1 # val: 2, new_val: 1
assert counter.get_and_increment() == 1 # val: 1, new_val: 2
def test_multiple_instances_same_path(counter_path_factory):
"""Verify that multiple FileCounter instances on the same path work correctly."""
counter_file = counter_path_factory("counter.txt")
counter1 = FileCounter(str(counter_file))
counter2 = FileCounter(str(counter_file))
assert counter1.get_and_increment() == 0
assert counter2.get_and_increment() == 1
assert counter1.decrement_and_get() == 1
assert counter2.get_and_increment() == 1
with open(counter_file, "r") as f:
assert f.read() == "2"
def test_multithreaded_access(counter_path_factory):
"""Ensure atomicity with multiple threads."""
counter_file = counter_path_factory("counter.txt")
counter = FileCounter(str(counter_file))
num_threads = 10
increments_per_thread = 100
def worker():
for _ in range(increments_per_thread):
counter.get_and_increment()
threads = [Thread(target=worker) for _ in range(num_threads)]
for t in threads:
t.start()
for t in threads:
t.join()
with open(counter_file, "r") as f:
final_value = int(f.read())
assert final_value == num_threads * increments_per_thread
def test_context_manager(counter_path_factory):
"""Test that the counter can be used as a context manager."""
counter_file = counter_path_factory("counter.txt")
counter = FileCounter(str(counter_file))
# Initial state should be 0
assert counter.get_and_increment() == 0
with open(counter_file) as f:
assert f.read() == "1"
with counter as wrapper:
# The wrapper's value is the original count before increment.
assert wrapper.value == 1
# It can also be used as an integer directly.
assert int(wrapper) == 1
with open(counter_file) as f:
assert f.read() == "2" # Inside context, value is 2
with open(counter_file) as f:
assert f.read() == "1" # Exited context, decremented back to 1
# After exit, the wrapper's 'ctr' attribute holds the new value.
assert wrapper.ctr == 1
async def test_cleanup_temp_multithreaded(tmp_path):
"""
Test that cleanup_temp correctly deletes the temp directory only
after the last thread has exited the context.
"""
# 1. Use the application's context to define the temp directory for this test.
# This is a cleaner approach than mocking.
base_dir = tmp_path / "base"
temp_dir_override = base_dir / "temp"
fn = FolderNames(base_paths=[base_dir])
init_default_paths(fn, base_paths_from_configuration=False)
# Override the default temp path
fn.temp_directory = temp_dir_override
from comfy.component_model.file_counter import cleanup_temp
num_threads = 5
# Barrier to make threads wait for each other before exiting.
barrier = Barrier(num_threads)
loop = asyncio.get_running_loop()
def worker():
"""The task for each thread. It enters the cleanup context and waits."""
with cleanup_temp():
# The temp directory and counter file should exist inside the context.
assert temp_dir_override.exists()
assert (temp_dir_override / "counter.txt").exists()
# After exiting, the directory should still exist until the last thread is done.
# Use the context manager to set the folder paths for the current async context.
with context_folder_names_and_paths(fn):
# Capture the current context, which now includes the folder_paths settings.
# Run the worker function in a thread pool, applying the captured context to each task.
with ThreadPoolExecutor(max_workers=num_threads) as executor:
tasks = [loop.run_in_executor(executor, contextvars.copy_context().run, worker) for _ in range(num_threads)]
# Wait for all threads to complete their work.
await asyncio.gather(*tasks)
# After all threads have joined, the counter should be 0, and the directory deleted.
assert not temp_dir_override.exists()
# the base dir is not going to be deleted
assert base_dir.exists()

View File

@ -16,6 +16,7 @@ from typing import List, Dict, Any, Generator
from PIL import Image
from comfy.cli_args import default_configuration
from comfy.cli_args_types import Configuration
from comfy_execution.graph_utils import GraphBuilder
from .test_execution import ComfyClient, RunResult
@ -188,7 +189,7 @@ class TestProgressIsolation:
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
# Start server
configuration = Configuration()
configuration = default_configuration()
configuration.listen = args_pytest["listen"]
configuration.port = args_pytest["port"]
configuration.cpu = True
@ -205,7 +206,7 @@ testing_pack:
yaml_path = str(tmp_path_factory.mktemp("comfy_background_server") / "extra_nodes.yaml")
with open(yaml_path, mode="wt") as f:
f.write(extra_nodes)
configuration.extra_model_paths_config = [str(yaml_path)]
configuration.extra_model_paths_config = [yaml_path]
yield from comfy_background_server_from_config(configuration)

View File

@ -1,5 +1,6 @@
import importlib.resources
import json
import logging
from importlib.abc import Traversable
import pytest
@ -11,6 +12,8 @@ from comfy.model_downloader_types import CivitFile, HuggingFile
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
from . import workflows
logger = logging.getLogger(__name__)
@pytest.fixture(scope="function", autouse=False)
async def client(tmp_path_factory) -> Comfy:
@ -35,6 +38,7 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
prompt = Prompt.validate(workflow)
# todo: add all the models we want to test a bit m2ore elegantly
outputs = {}
try:
outputs = await client.queue_prompt(prompt)
except TorchAudioNotFoundError:
@ -46,8 +50,14 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
elif any(v.class_type == "SaveAudio" for v in prompt.values()):
save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
elif any(v.class_type == "SaveAnimatedWEBP" for v in prompt.values()):
save_video_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAnimatedWEBP")
assert outputs[save_video_node_id]["images"][0]["filename"] is not None
elif any(v.class_type == "PreviewString" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
output_str = outputs[save_image_node_id]["string"][0]
assert output_str is not None
assert len(output_str) > 0
else:
assert len(outputs) > 0
logger.warning(f"test {workflow_name} did not have a node that could be checked for output")

7098
uv.lock

File diff suppressed because one or more lines are too long