packaging fixes

- enable user db
 - fix main_pre order everywhere
 - fix absolute to relative imports everywhere
 - async better supported
This commit is contained in:
doctorpangloss 2025-07-15 10:19:33 -07:00
parent c086c5e005
commit 96b4e04315
39 changed files with 604 additions and 151 deletions

View File

@ -82,7 +82,7 @@ limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load, # List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers. # usually to register additional checkers.
load-plugins=tests.absolute_import_checker load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker
# Pickle collected data for later comparisons. # Pickle collected data for later comparisons.
persistent=yes persistent=yes
@ -678,7 +678,7 @@ disable=raw-checker-failed,
# either give multiple identifier separated by comma (,) or put this option # either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where # multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples. # it should appear only once). See also the "--disable" option for examples.
enable=absolute-import-used enable=
[METHOD_ARGS] [METHOD_ARGS]

View File

@ -13,7 +13,7 @@ script_location = alembic_db
# sys.path path, will be prepended to sys.path if present. # sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. # defaults to the current working directory.
prepend_sys_path = . # prepend_sys_path = .
# timezone to use when rendering the date within the migration file # timezone to use when rendering the date within the migration file
# as well as the filename. # as well as the filename.
@ -63,7 +63,7 @@ version_path_separator = os
# are written from script.py.mako # are written from script.py.mako
# output_encoding = utf-8 # output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db sqlalchemy.url = sqlite:///./comfyui.db
[post_write_hooks] [post_write_hooks]

View File

View File

@ -1,3 +1,4 @@
# pylint: disable=no-member
from sqlalchemy import engine_from_config from sqlalchemy import engine_from_config
from sqlalchemy import pool from sqlalchemy import pool
@ -7,8 +8,7 @@ from alembic import context
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config
from ..app.database.models import Base
from comfy.app.database.models import Base
target_metadata = Base.metadata target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,

View File

@ -1,8 +1,10 @@
import logging import logging
import os import os
import shutil import shutil
from importlib.resources import files
from ...cli_args import args from ...cli_args import args
from ...component_model.files import get_package_as_path
Session = None Session = None
@ -15,6 +17,7 @@ from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True _DB_AVAILABLE = True
logger = logging.getLogger(__name__)
def dependencies_available(): def dependencies_available():
""" """
@ -32,9 +35,8 @@ def can_create_session():
def get_alembic_config(): def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..") config_path = str(files("comfy") / "alembic.ini")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) scripts_path = get_package_as_path("comfy.alembic_db")
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config = Config(config_path) config = Config(config_path)
config.set_main_option("script_location", scripts_path) config.set_main_option("script_location", scripts_path)
@ -53,7 +55,7 @@ def get_db_path():
def init_db(): def init_db():
db_url = args.database_url db_url = args.database_url
logging.debug(f"Database URL: {db_url}") logger.debug(f"Database URL: {db_url}")
db_path = get_db_path() db_path = get_db_path()
db_exists = os.path.exists(db_path) db_exists = os.path.exists(db_path)
@ -70,7 +72,7 @@ def init_db():
target_rev = script.get_current_head() target_rev = script.get_current_head()
if target_rev is None: if target_rev is None:
logging.warning("No target revision found.") logger.debug("No target revision found.")
elif current_rev != target_rev: elif current_rev != target_rev:
# Backup the database pre upgrade # Backup the database pre upgrade
backup_path = db_path + ".bkp" backup_path = db_path + ".bkp"
@ -81,13 +83,13 @@ def init_db():
try: try:
command.upgrade(config, target_rev) command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}") logger.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e: except Exception as e:
if backup_path: if backup_path:
# Restore the database from backup if upgrade fails # Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path) shutil.copy(backup_path, db_path)
os.remove(backup_path) os.remove(backup_path)
logging.exception("Error upgrading database: ") logger.exception("Error upgrading database: ")
raise e raise e
global Session global Session

View File

@ -5,7 +5,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch from ..ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, \
get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module): class ControlNetEmbedder(nn.Module):

View File

@ -11,7 +11,7 @@ import configargparse as argparse
from . import __version__ from . import __version__
from . import options from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \ from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config
# todo: move this # todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@ -261,8 +261,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
) )
parser.add_argument("--database-url", type=str, default=f"sqlite:///:memory:", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") default_db_url = db_config()
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.") parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
# now give plugins a chance to add configuration # now give plugins a chance to add configuration

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import logging
import os import os
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
@ -20,6 +21,22 @@ class LatentPreviewMethod(enum.Enum):
ConfigObserver = Callable[[str, Any], None] ConfigObserver = Callable[[str, Any], None]
def db_config() -> str:
from .vendor.appdirs import user_data_dir
logger = logging.getLogger(__name__)
try:
data_dir = user_data_dir(appname="comfyui")
os.makedirs(data_dir, exist_ok=True)
db_path = os.path.join(data_dir, "comfy.db")
default_db_url = f"sqlite:///{db_path}"
except Exception as e:
# Fallback to an in-memory database if the user directory can't be accessed
logger.warning(f"Could not determine user data directory for database, falling back to in-memory: {e}")
default_db_url = "sqlite:///:memory:"
return default_db_url
def is_valid_directory(path: str) -> str: def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions.""" """Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path): if not os.path.exists(path):
@ -261,7 +278,7 @@ class Configuration(dict):
self.front_end_version: str = "comfyanonymous/ComfyUI@latest" self.front_end_version: str = "comfyanonymous/ComfyUI@latest"
self.front_end_root: Optional[str] = None self.front_end_root: Optional[str] = None
self.comfy_api_base: str = "https://api.comfy.org" self.comfy_api_base: str = "https://api.comfy.org"
self.database_url: str = "sqlite:///:memory:" self.database_url: str = db_config()
for key, value in kwargs.items(): for key, value in kwargs.items():
self[key] = value self[key] = value

View File

@ -14,11 +14,11 @@ from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode from opentelemetry.trace import Status, StatusCode
from ..cmd.main_pre import tracer
from .client_types import V1QueuePromptResponse from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration from ..cli_args_types import Configuration
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable from ..component_model.make_mutable import make_mutable
from ..distributed.executors import ContextVarExecutor from ..distributed.executors import ContextVarExecutor
@ -97,7 +97,7 @@ async def __execute_prompt(
else: else:
prompt_executor.server = progress_handler prompt_executor.server = progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id}, await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple.good_output_node_ids) execute_outputs=validation_tuple.good_output_node_ids)
return prompt_executor.outputs_ui return prompt_executor.outputs_ui
except Exception as exc_info: except Exception as exc_info:
@ -195,7 +195,7 @@ class Comfy:
if isinstance(prompt, str): if isinstance(prompt, str):
prompt = json.loads(prompt) prompt = json.loads(prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
from comfy.api.components.schema.prompt import Prompt from ..api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt) prompt = Prompt.validate(prompt)
outputs = await self.queue_prompt(prompt) outputs = await self.queue_prompt(prompt)
return V1QueuePromptResponse(urls=[], outputs=outputs) return V1QueuePromptResponse(urls=[], outputs=outputs)

View File

@ -19,14 +19,15 @@ from typing import List, Optional, Tuple, Literal
import torch import torch
from opentelemetry.trace import get_current_span, StatusCode, Status from opentelemetry.trace import get_current_span, StatusCode, Status
# order matters
from .main_pre import tracer
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
DependencyAwareCache, \ DependencyAwareCache, \
BasicCache BasicCache
# order matters
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
from .main_pre import tracer
from .. import interruption from .. import interruption
from .. import model_management from .. import model_management
from ..cli_args import args from ..cli_args import args
@ -37,7 +38,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path from ..component_model.files import canonicalize_path
from ..component_model.module_property import create_module_properties from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
from ..execution_context import context_execute_node, context_execute_prompt from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception from ..execution_ext import should_panic_on_exception
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@ -388,7 +389,7 @@ def format_value(x) -> FormattedValue:
return str(x.__class__) return str(x.__class__)
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
""" """
:param server: :param server:
@ -402,8 +403,8 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
:param pending_subgraph_results: :param pending_subgraph_results:
:return: :return:
""" """
with context_execute_node(_node_id): with context_execute_node(node_id):
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
@ -516,7 +517,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
unblock() unblock()
asyncio.create_task(await_completion()) asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None) return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
caches.ui.set(unique_id, { caches.ui.set(unique_id, {
"meta": { "meta": {
@ -685,11 +686,6 @@ class PromptExecutor:
if ex is not None and self.raise_exceptions: if ex is not None and self.raise_exceptions:
raise ex raise ex
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# torchao and potentially other optimization approaches break when the models are created in inference mode # torchao and potentially other optimization approaches break when the models are created in inference mode
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
@ -1206,7 +1202,7 @@ class PromptQueue(AbstractPromptQueue):
self.server.queue_updated() self.server.queue_updated()
return copy.deepcopy(item_with_future.queue_tuple), task_id return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id: str, outputs: dict, def task_done(self, item_id: str, outputs: HistoryResultDict,
status: Optional[ExecutionStatus]): status: Optional[ExecutionStatus]):
history_result = outputs history_result = outputs
with self.mutex: with self.mutex:
@ -1215,9 +1211,9 @@ class PromptQueue(AbstractPromptQueue):
if len(self.history) > MAXIMUM_HISTORY_SIZE: if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history))) self.history.pop(next(iter(self.history)))
status_dict: Optional[dict] = None status_dict = None
if status is not None: if status is not None:
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict()) status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
outputs_ = history_result["outputs"] outputs_ = history_result["outputs"]
# Remove sensitive data from extra_data before storing in history # Remove sensitive data from extra_data before storing in history
@ -1225,11 +1221,13 @@ class PromptQueue(AbstractPromptQueue):
if sensitive_val in prompt[3]: if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val) prompt[3].pop(sensitive_val)
self.history[prompt[1]] = { history_entry: HistoryEntry = {
"prompt": prompt, "prompt": prompt,
"outputs": copy.deepcopy(outputs_), "outputs": copy.deepcopy(outputs_),
'status': status_dict,
} }
if status_dict is not None:
history_entry["status"] = status_dict
self.history[prompt[1]] = history_entry
self.history[prompt[1]].update(history_result) self.history[prompt[1]].update(history_result)
self.server.queue_updated() self.server.queue_updated()
if queue_item.completed: if queue_item.completed:

View File

@ -10,20 +10,22 @@ import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from comfy.component_model.entrypoints_common import configure_application_paths, executor_from_args # main_pre must be the earliest import
from .main_pre import args
from . import hook_breaker_ac10a0 from . import hook_breaker_ac10a0
from .extra_model_paths import load_extra_path_config from .extra_model_paths import load_extra_path_config
# main_pre must be the earliest import since it suppresses some spurious warnings
from .main_pre import args
from .. import model_management from .. import model_management
from ..analytics.analytics import initialize_event_tracking from ..analytics.analytics import initialize_event_tracking
from ..cmd import cuda_malloc from ..cmd import cuda_malloc
from ..cmd import folder_paths from ..cmd import folder_paths
from ..cmd import server as server_module from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
from ..distributed.distributed_prompt_queue import DistributedPromptQueue from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..distributed.server_stub import ServerStub from ..distributed.server_stub import ServerStub
from ..nodes.package import import_all_nodes_in_workspace from ..nodes.package import import_all_nodes_in_workspace
from ..nodes_context import get_nodes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,6 +44,10 @@ def cuda_malloc_warning():
def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer): def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
asyncio.run(_prompt_worker(q, server_instance))
async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
from ..cmd import execution from ..cmd import execution
from ..component_model import queue_types from ..component_model import queue_types
from .. import model_management from .. import model_management
@ -68,7 +74,7 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS
prompt_id = item[1] prompt_id = item[1]
server_instance.last_prompt_id = prompt_id server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4]) await e.execute_async(item[2], prompt_id, item[3], item[4])
need_gc = True need_gc = True
q.task_done(item_id, q.task_done(item_id,
e.history_result, e.history_result,
@ -174,17 +180,16 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
for config_path in itertools.chain(*args.extra_model_paths_config): for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path) load_extra_path_config(config_path)
# always create directories when started interactively
folder_paths.create_directories()
if args.create_directories: if args.create_directories:
# then, import and exit
import_all_nodes_in_workspace(raise_on_failure=False) import_all_nodes_in_workspace(raise_on_failure=False)
folder_paths.create_directories() folder_paths.create_directories()
exit(0) exit(0)
elif args.quick_test_for_ci:
setup_database() import_all_nodes_in_workspace(raise_on_failure=True)
exit(0)
if args.windows_standalone_build: if args.windows_standalone_build:
folder_paths.create_directories()
try: try:
from . import new_updater from . import new_updater
new_updater.update_windows_updater() new_updater.update_windows_updater()
@ -198,7 +203,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
# at this stage, it's safe to import nodes # at this stage, it's safe to import nodes
hook_breaker_ac10a0.save_functions() hook_breaker_ac10a0.save_functions()
server.nodes = import_all_nodes_in_workspace() server.nodes = get_nodes()
hook_breaker_ac10a0.restore_functions() hook_breaker_ac10a0.restore_functions()
# as a side effect, this also populates the nodes for execution # as a side effect, this also populates the nodes for execution
@ -221,6 +226,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
server.add_routes() server.add_routes()
cuda_malloc_warning() cuda_malloc_warning()
setup_database()
# in a distributed setting, the default prompt worker will not be able to send execution events via the websocket # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
worker_thread_server = server if not distributed else ServerStub() worker_thread_server = server if not distributed else ServerStub()
@ -254,14 +260,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
logger.debug(f"Setting input directory to: {input_dir}") logger.debug(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir) folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci: # now that nodes are loaded, create directories
# for CI purposes, try importing all the nodes
import_all_nodes_in_workspace(raise_on_failure=True)
return
else:
# we no longer lazily load nodes. we'll do it now for the sake of creating directories
import_all_nodes_in_workspace(raise_on_failure=False)
# now that nodes are loaded, create more directories if appropriate
folder_paths.create_directories() folder_paths.create_directories()
if len(args.workflows) > 0: if len(args.workflows) > 0:

View File

@ -1169,7 +1169,7 @@ class PromptServer(ExecutorToClientProgress):
await runner.setup() await runner.setup()
if 'tls_keyfile' in args or 'tls_certfile' in args: if 'tls_keyfile' in args or 'tls_certfile' in args:
raise ValueError("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md") logger.warning("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
def is_ipv4(address: str, *args): def is_ipv4(address: str, *args):
try: try:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import typing import typing
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from .executor_types import HistoryResultDict
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems
@ -42,7 +43,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
pass pass
@abstractmethod @abstractmethod
def task_done(self, item_id: str, outputs: dict, def task_done(self, item_id: str, outputs: HistoryResultDict,
status: typing.Optional[ExecutionStatus]): status: typing.Optional[ExecutionStatus]):
""" """
Signals to the user interface that the task with the specified id is completed Signals to the user interface that the task with the specified id is completed

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import copy
from enum import Enum from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple from typing import Tuple
@ -24,6 +25,13 @@ class ExecutionStatus(NamedTuple):
completed: bool completed: bool
messages: List[str] messages: List[str]
def as_dict(self) -> ExecutionStatusAsDict:
return {
"status_str": self.status_str,
"completed": self.completed,
"messages": copy.copy(self.messages),
}
class ExecutionError(RuntimeError): class ExecutionError(RuntimeError):
def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args): def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args):

View File

@ -13,12 +13,12 @@ from aio_pika import connect_robust
from aio_pika.abc import AbstractConnection, AbstractChannel from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import JsonRPC from aio_pika.patterns import JsonRPC
from ..cmd.main_pre import tracer
from .distributed_progress import ProgressHandlers from .distributed_progress import ProgressHandlers
from .distributed_types import RpcRequest, RpcReply from .distributed_types import RpcRequest, RpcReply
from .history import History from .history import History
from .server_stub import ServerStub from .server_stub import ServerStub
from ..auth.permissions import jwt_decode from ..auth.permissions import jwt_decode
from ..cmd.main_pre import tracer
from ..cmd.server import PromptServer from ..cmd.server import PromptServer
from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict

View File

@ -10,12 +10,12 @@ from aio_pika.patterns import JsonRPC
from aiohttp import web from aiohttp import web
from aiormq import AMQPConnectionError from aiormq import AMQPConnectionError
from ..cmd.main_pre import tracer
from .executors import ContextVarExecutor from .executors import ContextVarExecutor
from .distributed_progress import DistributedExecutorToClientProgress from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply from .distributed_types import RpcRequest, RpcReply
from .process_pool_executor import ProcessPoolExecutor from .process_pool_executor import ProcessPoolExecutor
from ..client.embedded_comfy_client import Comfy from ..client.embedded_comfy_client import Comfy
from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus from ..component_model.queue_types import ExecutionStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -2,7 +2,6 @@ import asyncio
from ..cmd.main_pre import args from ..cmd.main_pre import args
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor
async def main(): async def main():

View File

@ -1,7 +1,8 @@
import torch import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management from ..ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.modules.attention import optimized_attention_for_device from ..model_management import cast_to_device
from ..text_encoders.bert import BertAttention
class Dino2AttentionOutput(torch.nn.Module): class Dino2AttentionOutput(torch.nn.Module):
@ -29,7 +30,7 @@ class LayerScale(torch.nn.Module):
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x): def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) return x * cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module): class SwiGLUFFN(torch.nn.Module):
@ -117,7 +118,7 @@ class Dino2Embeddings(torch.nn.Module):
x = self.patch_embeddings(pixel_values) x = self.patch_embeddings(pixel_values)
# TODO: mask_token? # TODO: mask_token?
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) x = x + cast_to_device(self.position_embeddings, x.device, x.dtype)
return x return x

View File

@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from importlib.abc import Traversable # pylint: disable=no-name-in-module
from importlib.resources import files from importlib.resources import files
from pathlib import Path from pathlib import Path
@ -10,7 +9,7 @@ KNOWN_CHAT_TEMPLATES = {}
def _update_known_chat_templates(): def _update_known_chat_templates():
try: try:
_chat_templates: Traversable = files(__package__) / "chat_templates" _chat_templates = files(__package__) / "chat_templates"
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()} _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates) KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc: except ImportError as exc:

View File

@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, T
from transformers.utils import PaddingStrategy from transformers.utils import PaddingStrategy
from typing_extensions import TypedDict, NotRequired from typing_extensions import TypedDict, NotRequired
from comfy.component_model.tensor_types import RGBImageBatch from ..component_model.tensor_types import RGBImageBatch
class ProcessorResult(TypedDict): class ProcessorResult(TypedDict):

View File

@ -23,7 +23,7 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
from torch import nn from torch import nn
from comfy.ldm.modules.attention import optimized_attention from ..modules.attention import optimized_attention
def get_normalization(name: str, channels: int, weight_args={}, operations=None): def get_normalization(name: str, channels: int, weight_args={}, operations=None):

View File

@ -11,7 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms from torchvision import transforms
from comfy.ldm.modules.attention import optimized_attention from ..modules.attention import optimized_attention
def apply_rotary_pos_emb( def apply_rotary_pos_emb(
t: torch.Tensor, t: torch.Tensor,

View File

@ -1,12 +1,6 @@
import torch import torch
from torch import nn from torch import nn
from comfy.ldm.flux.layers import ( from ..flux.layers import DoubleStreamBlock, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
DoubleStreamBlock,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
class Hunyuan3Dv2(nn.Module): class Hunyuan3Dv2(nn.Module):

View File

@ -5,10 +5,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention from ..modules.diffusionmodules.model import vae_attention
import comfy.ops from ...ops import disable_weight_init
ops = comfy.ops.disable_weight_init ops = disable_weight_init
CACHE_T = 2 CACHE_T = 2

View File

@ -31,9 +31,9 @@ import psutil
import torch import torch
from opentelemetry.trace import get_current_span from opentelemetry.trace import get_current_span
from .cmd.main_pre import tracer
from . import interruption from . import interruption
from .cli_args import args, PerformanceFeature from .cli_args import args, PerformanceFeature
from .cmd.main_pre import tracer
from .component_model.deprecation import _deprecate_method from .component_model.deprecation import _deprecate_method
from .model_management_types import ModelManageable from .model_management_types import ModelManageable

View File

@ -11,12 +11,13 @@ from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode from opentelemetry.trace import Span, Status, StatusCode
from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer from ..cmd.main_pre import tracer
from .package_typing import ExportedNodes
from ..component_model.files import get_package_as_path from ..component_model.files import get_package_as_path
_comfy_nodes: ExportedNodes = ExportedNodes() _nodes_available_at_startup: ExportedNodes = ExportedNodes()
logger = logging.getLogger(__name__)
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
@ -55,7 +56,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes)) timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
except Exception as exc: except Exception as exc:
module_decl = None module_decl = None
logging.error(f"{full_name} import failed", exc_info=exc) logger.error(f"{full_name} import failed", exc_info=exc)
span.set_status(Status(StatusCode.ERROR)) span.set_status(Status(StatusCode.ERROR))
span.record_exception(exc) span.record_exception(exc)
exceptions.append(exc) exceptions.append(exc)
@ -84,7 +85,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
potential_path_error: AttributeError = x potential_path_error: AttributeError = x
if potential_path_error.name == '__path__': if potential_path_error.name == '__path__':
continue continue
logging.error(f"{full_name} import failed", exc_info=x) logger.error(f"{full_name} import failed", exc_info=x)
success = False success = False
exceptions.append(x) exceptions.append(x)
span.set_status(Status(StatusCode.ERROR)) span.set_status(Status(StatusCode.ERROR))
@ -93,7 +94,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings): if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
for (duration, module_name, success, new_nodes) in sorted(timings): for (duration, module_name, success, new_nodes) in sorted(timings):
logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)") logger.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
if raise_on_failure and len(exceptions) > 0: if raise_on_failure and len(exceptions) > 0:
try: try:
raise ExceptionGroup("Node import failed", exceptions) # pylint: disable=using-exception-groups-in-unsupported-version raise ExceptionGroup("Node import failed", exceptions) # pylint: disable=using-exception-groups-in-unsupported-version
@ -105,12 +106,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
@tracer.start_as_current_span("Import All Nodes In Workspace") @tracer.start_as_current_span("Import All Nodes In Workspace")
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes: def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order # now actually import the nodes, to improve control of node loading order
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from ..cli_args import args from ..cli_args import args
from . import base_nodes
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
# only load these nodes once # only load these nodes once
if len(_comfy_nodes) == 0: if len(_nodes_available_at_startup) == 0:
# import base_nodes first
from . import base_nodes
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
base_and_extra = reduce(lambda x, y: x.update(y), base_and_extra = reduce(lambda x, y: x.update(y),
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [ map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
# this is the list of default nodes to import # this is the list of default nodes to import
@ -121,9 +126,9 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
custom_nodes_mappings = ExportedNodes() custom_nodes_mappings = ExportedNodes()
if args.disable_all_custom_nodes: if args.disable_all_custom_nodes:
logging.info("Loading custom nodes was disabled, only base and extra nodes were loaded") logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
_comfy_nodes.update(base_and_extra) _nodes_available_at_startup.update(base_and_extra)
return _comfy_nodes return _nodes_available_at_startup
# load from entrypoints # load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'): for entry_point in entry_points().select(group='comfyui.custom_nodes'):
@ -131,7 +136,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
try: try:
module = entry_point.load() module = entry_point.load()
except ModuleNotFoundError as module_not_found_error: except ModuleNotFoundError as module_not_found_error:
logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
continue continue
# Ensure that what we've loaded is indeed a module # Ensure that what we've loaded is indeed a module
@ -146,5 +151,5 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
# don't allow custom nodes to overwrite base nodes # don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra custom_nodes_mappings -= base_and_extra
_comfy_nodes.update(base_and_extra + custom_nodes_mappings) _nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
return _comfy_nodes return _nodes_available_at_startup

View File

@ -7,8 +7,6 @@ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, T
from typing_extensions import TypedDict, NotRequired from typing_extensions import TypedDict, NotRequired
from comfy.comfy_types import FileLocator
T = TypeVar('T') T = TypeVar('T')

View File

@ -131,6 +131,7 @@ dev = [
"freezegun", "freezegun",
"coverage", "coverage",
"pylint", "pylint",
"astroid",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@ -24,18 +24,46 @@ class AbsoluteImportChecker(BaseChecker):
super().__init__(linter) super().__init__(linter)
def visit_importfrom(self, node: nodes.ImportFrom) -> None: def visit_importfrom(self, node: nodes.ImportFrom) -> None:
current_file = node.root().file """
if current_file is None: Check for absolute imports from the same top-level package.
This method is called for every `from ... import ...` statement.
It checks if a module within 'comfy' or 'comfy_extras' packages
is using an absolute import from its own package, which should
be a relative import instead.
For example, inside `comfy/nodes/logic.py`, an import like
`from comfy.utils import some_function` will be flagged.
The preferred way would be `from ..utils import some_function`.
"""
# An import is relative if its level is greater than 0.
# e.g., from . import foo (level=1), from .. import bar (level=2)
# We only want to check absolute imports, so we skip relative ones.
if node.level and node.level > 0:
return return
package_path = os.path.dirname(current_file) # Get the fully qualified name of the module being linted.
package_name = os.path.basename(package_path) # For a file at '.../comfy/nodes/common.py', this will be 'comfy.nodes.common'.
module_qname = node.root().qname()
if node.modname.startswith(package_name) and package_name in ['comfy', 'comfy_extras']: # `node.modname` is the module name in the `from` statement.
import_parts = node.modname.split('.') # For `from comfy.utils import x`, `modname` is `comfy.utils`.
imported_modname = node.modname
if not imported_modname:
return
if import_parts[0] == package_name: # We are only interested in modules within 'comfy' or 'comfy_extras'.
self.add_message('absolute-import-used', node=node, args=(node.modname,)) # We determine this by looking at the first part of the qualified name.
current_top_package = module_qname.split('.')[0]
if current_top_package not in ['comfy', 'comfy_extras']:
return
imported_top_package = imported_modname.split('.')[0]
# If the top-level package of the imported module is the same as the
# current module's top-level package, it's an internal absolute import.
if imported_top_package == current_top_package:
self.add_message('absolute-import-used', node=node, args=(imported_modname,))
def register(linter: "PyLinter") -> None: def register(linter: "PyLinter") -> None:

View File

@ -1,50 +1,25 @@
import sys
import time
import logging import logging
import multiprocessing import multiprocessing
import os import os
import pathlib import pathlib
import subprocess
import sys
import time
import urllib
from typing import Tuple, List
import pytest import pytest
import requests import requests
import socket
import subprocess
import urllib
from testcontainers.rabbitmq import RabbitMqContainer
from typing import Tuple, List
from comfy.cli_args_types import Configuration from comfy.cli_args_types import Configuration
logging.getLogger("pika").setLevel(logging.CRITICAL + 1) logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1) logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
# fixes issues with running the testcontainers rabbitmqcontainer on Windows # fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost" os.environ["TC_HOST"] = "localhost"
def get_lan_ip():
"""
Finds the host's IP address on the LAN it's connected to.
Returns:
str: The IP address of the host on the LAN.
"""
# Create a dummy socket
s = None
try:
# Connect to a dummy address (Here, Google's public DNS server)
# The actual connection is not made, but this allows finding out the LAN IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
finally:
if s is not None:
s.close()
return ip
def run_server(server_arguments: Configuration): def run_server(server_arguments: Configuration):
from comfy.cmd.main import main from comfy.cmd.main import main
from comfy.cli_args import args from comfy.cli_args import args
@ -95,6 +70,11 @@ def has_gpu() -> bool:
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"]) @pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1): def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from testcontainers.rabbitmq import RabbitMqContainer
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors") hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors") hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
@ -108,8 +88,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
frontend_command = [ frontend_command = [
"comfyui", "comfyui",
"--listen=0.0.0.0", "--listen=127.0.0.1",
"--port=9001", "--port=19001",
"--cpu", "--cpu",
"--distributed-queue-frontend", "--distributed-queue-frontend",
f"-w={str(tmp_path)}", f"-w={str(tmp_path)}",
@ -122,7 +102,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
for i in range(num_workers): for i in range(num_workers):
backend_command = [ backend_command = [
"comfyui-worker", "comfyui-worker",
f"--port={9002 + i}", f"--port={19002 + i}",
f"-w={str(tmp_path)}", f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}", f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}" f"--executor-factory={executor_factory}"
@ -130,7 +110,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
try: try:
server_address = f"http://{get_lan_ip()}:9001" server_address = f"http://127.0.0.1:19001"
start_time = time.time() start_time = time.time()
connected = False connected = False
while time.time() - start_time < 60: while time.time() - start_time < 60:

View File

@ -15,7 +15,7 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex
DependencyCycleError DependencyCycleError
from comfy.distributed.server_stub import ServerStub from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import context_add_custom_nodes from comfy.execution_context import context_add_custom_nodes
from comfy.graph_utils import GraphBuilder, Node from comfy_execution.graph_utils import GraphBuilder, Node
from comfy.nodes.package_typing import ExportedNodes from comfy.nodes.package_typing import ExportedNodes
current_test_name = ContextVar('current_test_name', default=None) current_test_name = ContextVar('current_test_name', default=None)

View File

@ -1,4 +1,4 @@
from comfy.graph_utils import GraphBuilder, is_link from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker from comfy.graph import ExecutionBlocker
from .tools import VariantSupport from .tools import VariantSupport

View File

@ -3,7 +3,7 @@ import time
import asyncio import asyncio
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from .tools import VariantSupport from .tools import VariantSupport
from comfy.graph_utils import GraphBuilder from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO from comfy.comfy_types import IO

View File

@ -1,4 +1,4 @@
from comfy.graph_utils import GraphBuilder from comfy_execution.graph_utils import GraphBuilder
from .tools import VariantSupport from .tools import VariantSupport
@VariantSupport() @VariantSupport()

View File

@ -0,0 +1,129 @@
"""Pylint checker for ensuring main_pre is imported first."""
from typing import TYPE_CHECKING, Optional, Union
import astroid
from astroid import nodes
from pylint.checkers import BaseChecker
if TYPE_CHECKING:
from pylint.lint import PyLinter
class MainPreImportOrderChecker(BaseChecker):
"""
Ensures that imports from 'comfy.cmd.main_pre' or similar setup modules
occur before any other relative imports or imports from the 'comfy' package.
This is important for code that relies on setup being performed by 'main_pre'
before other modules from the same package are imported.
"""
name = 'main-pre-import-order'
msgs = {
'W0002': (
'Setup import %s must be placed before other package imports like %s.',
'main-pre-import-not-first',
"To ensure necessary setup is performed, 'comfy.cmd.main_pre' or similar "
"setup imports must precede other relative imports or imports from the "
"'comfy' family of packages."
),
}
def _is_main_pre_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
"""Checks if an import statement is for 'comfy.cmd.main_pre'."""
if isinstance(stmt, nodes.Import):
for name, _ in stmt.names:
if name == 'comfy.cmd.main_pre' or name.startswith('comfy.cmd.main_pre.'):
return True
return False
if isinstance(stmt, nodes.ImportFrom):
qname: Optional[str] = None
if stmt.level == 0:
qname = stmt.modname
else:
try:
# Attempt to resolve the relative import to a fully qualified name
imported_module = stmt.do_import_module()
qname = imported_module.qname()
except astroid.AstroidError:
# Fallback for unresolved relative imports, check the literal module name
if stmt.modname and stmt.modname.endswith('.main_pre'):
return True
# Heuristic for `from ..cmd import main_pre` in `comfy/entrypoints/*`
if stmt.modname == 'cmd' and stmt.root().qname().startswith('comfy'):
for name, _ in stmt.names:
if name == 'main_pre':
return True
if not qname:
return False
# from comfy.cmd import main_pre
if qname == 'comfy.cmd':
for name, _ in stmt.names:
if name == 'main_pre':
return True
# from comfy.cmd.main_pre import ... OR from a.b.c.main_pre import ...
if qname == 'comfy.cmd.main_pre' or qname.endswith('.main_pre'):
return True
return False
def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
"""
Checks if an import is a relative import or an import from
the 'comfy' package family, and is not a 'main_pre' import.
"""
if self._is_main_pre_import(stmt):
return False
if isinstance(stmt, nodes.ImportFrom):
if stmt.level and stmt.level > 0: # Any relative import
return True
if stmt.modname and stmt.modname.startswith('comfy'):
return True
if isinstance(stmt, nodes.Import):
for name, _ in stmt.names:
if name.startswith('comfy'):
return True
return False
def visit_module(self, node: nodes.Module) -> None:
"""Checks the order of imports within a module."""
imports = [
stmt for stmt in node.body
if isinstance(stmt, (nodes.Import, nodes.ImportFrom))
]
main_pre_import_node = None
for stmt in imports:
if self._is_main_pre_import(stmt):
main_pre_import_node = stmt
break
# If there's no main_pre import, there's nothing to check.
if not main_pre_import_node:
return
for stmt in imports:
# We only care about imports that appear before the main_pre import
if stmt.fromlineno >= main_pre_import_node.fromlineno:
break
if self._is_other_relevant_import(stmt):
self.add_message(
'main-pre-import-not-first',
node=main_pre_import_node,
args=(main_pre_import_node.as_string(), stmt.as_string())
)
return # Report once per file and exit to avoid spam
def register(linter: "PyLinter") -> None:
"""This function is required for a Pylint plugin."""
linter.register_checker(MainPreImportOrderChecker(linter))

View File

@ -0,0 +1,292 @@
import os
from unittest.mock import patch
import pytest
from aiohttp import web
from comfy.app.user_manager import UserManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
) if file else tmp_path
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt", data=content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was created with correct content
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == content
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
new_content = b"updated content"
resp = await client.post("/userdata/test.txt", data=new_content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was overwritten
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == new_content
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
assert resp.status == 409
# Verify original content unchanged
with open(tmp_path / "test.txt", "r") as f:
assert f.read() == "initial content"
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
assert resp.status == 200
result = await resp.json()
assert result["path"] == "test.txt"
assert result["size"] == len(content)
assert "modified" in result
async def test_move_userdata(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt")
assert resp.status == 200
assert await resp.text() == '"dest.txt"'
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create source and destination files
with open(tmp_path / "source.txt", "w") as f:
f.write("source content")
with open(tmp_path / "dest.txt", "w") as f:
f.write("destination content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
assert resp.status == 409
# Verify files remain unchanged
with open(tmp_path / "source.txt", "r") as f:
assert f.read() == "source content"
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "destination content"
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
assert resp.status == 200
result = await resp.json()
assert result["path"] == "dest.txt"
assert result["size"] == len("test content")
assert "modified" in result
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_listuserdata_v2_empty_root(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata")
assert resp.status == 200
assert await resp.json() == []
async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata?path=does_not_exist")
assert resp.status == 404
async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
(tmp_path / "test_dir" / "file1.txt").write_text("content")
(tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content")
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata?path=test_dir")
assert resp.status == 200
data = await resp.json()
file_paths = {item["path"] for item in data if item["type"] == "file"}
assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"}
async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch):
# Force backslash as os separator
monkeypatch.setattr(os, 'sep', '\\')
monkeypatch.setattr(os.path, 'sep', '\\')
os.makedirs(tmp_path / "test_dir" / "subdir")
(tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x")
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata?path=test_dir")
assert resp.status == 200
data = await resp.json()
for item in data:
assert "/" in item["path"]
assert "\\" not in item["path"]
async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path):
# Create a directory with a space in its name and a file inside
os.makedirs(tmp_path / "my dir")
(tmp_path / "my dir" / "file.txt").write_text("content")
client = await aiohttp_client(app)
# Use URL-encoded space in path parameter
resp = await client.get("/v2/userdata?path=my%20dir&recurse=false")
assert resp.status == 200
data = await resp.json()
assert len(data) == 1
entry = data[0]
assert entry["name"] == "file.txt"
# Ensure the path is correctly decoded and uses forward slash
assert entry["path"] == "my dir/file.txt"

View File

@ -122,7 +122,7 @@ def test_bool_request_parameter():
assert v == True assert v == True
def test_string_enum_request_parameter(): async def test_string_enum_request_parameter():
nt = StringEnumRequestParameter.INPUT_TYPES() nt = StringEnumRequestParameter.INPUT_TYPES()
assert nt is not None assert nt is not None
n = StringEnumRequestParameter() n = StringEnumRequestParameter()
@ -155,8 +155,9 @@ def test_string_enum_request_parameter():
} }
from comfy.cmd.execution import validate_inputs from comfy.cmd.execution import validate_inputs
validated: dict[str, ValidateInputsTuple] = {} validated: dict[str, ValidateInputsTuple] = {}
validated["1"] = validate_inputs(prompt, "1", validated) prompt_id = str(uuid.uuid4())
validated["2"] = validate_inputs(prompt, "2", validated) validated["1"] = await validate_inputs(prompt_id, prompt, "1", validated)
validated["2"] = await validate_inputs(prompt_id, prompt, "2", validated)
assert validated["2"].valid assert validated["2"].valid