diff --git a/.pylintrc b/.pylintrc index a095aa977..3accdb6d6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -82,7 +82,7 @@ limit-inference-results=100 # List of plugins (as comma separated values of python module names) to load, # usually to register additional checkers. -load-plugins=tests.absolute_import_checker +load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker # Pickle collected data for later comparisons. persistent=yes @@ -678,7 +678,7 @@ disable=raw-checker-failed, # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where # it should appear only once). See also the "--disable" option for examples. -enable=absolute-import-used +enable= [METHOD_ARGS] diff --git a/alembic.ini b/comfy/alembic.ini similarity index 98% rename from alembic.ini rename to comfy/alembic.ini index 12f18712f..bb2dafd20 100644 --- a/alembic.ini +++ b/comfy/alembic.ini @@ -13,7 +13,7 @@ script_location = alembic_db # sys.path path, will be prepended to sys.path if present. # defaults to the current working directory. -prepend_sys_path = . +# prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. @@ -63,7 +63,7 @@ version_path_separator = os # are written from script.py.mako # output_encoding = utf-8 -sqlalchemy.url = sqlite:///user/comfyui.db +sqlalchemy.url = sqlite:///./comfyui.db [post_write_hooks] diff --git a/alembic_db/README.md b/comfy/alembic_db/README.md similarity index 100% rename from alembic_db/README.md rename to comfy/alembic_db/README.md diff --git a/comfy/alembic_db/__init__.py b/comfy/alembic_db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/alembic_db/env.py b/comfy/alembic_db/env.py similarity index 96% rename from alembic_db/env.py rename to comfy/alembic_db/env.py index 73d51eccf..90096f7ad 100644 --- a/alembic_db/env.py +++ b/comfy/alembic_db/env.py @@ -1,3 +1,4 @@ +# pylint: disable=no-member from sqlalchemy import engine_from_config from sqlalchemy import pool @@ -7,8 +8,7 @@ from alembic import context # access to the values within the .ini file in use. config = context.config - -from comfy.app.database.models import Base +from ..app.database.models import Base target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, diff --git a/alembic_db/script.py.mako b/comfy/alembic_db/script.py.mako similarity index 100% rename from alembic_db/script.py.mako rename to comfy/alembic_db/script.py.mako diff --git a/comfy/app/database/db.py b/comfy/app/database/db.py index 3089e3ba0..c995f061e 100644 --- a/comfy/app/database/db.py +++ b/comfy/app/database/db.py @@ -1,8 +1,10 @@ import logging import os import shutil +from importlib.resources import files from ...cli_args import args +from ...component_model.files import get_package_as_path Session = None @@ -15,6 +17,7 @@ from sqlalchemy.orm import sessionmaker _DB_AVAILABLE = True +logger = logging.getLogger(__name__) def dependencies_available(): """ @@ -32,9 +35,8 @@ def can_create_session(): def get_alembic_config(): - root_path = os.path.join(os.path.dirname(__file__), "../..") - config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) - scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) + config_path = str(files("comfy") / "alembic.ini") + scripts_path = get_package_as_path("comfy.alembic_db") config = Config(config_path) config.set_main_option("script_location", scripts_path) @@ -53,7 +55,7 @@ def get_db_path(): def init_db(): db_url = args.database_url - logging.debug(f"Database URL: {db_url}") + logger.debug(f"Database URL: {db_url}") db_path = get_db_path() db_exists = os.path.exists(db_path) @@ -70,7 +72,7 @@ def init_db(): target_rev = script.get_current_head() if target_rev is None: - logging.warning("No target revision found.") + logger.debug("No target revision found.") elif current_rev != target_rev: # Backup the database pre upgrade backup_path = db_path + ".bkp" @@ -81,13 +83,13 @@ def init_db(): try: command.upgrade(config, target_rev) - logging.info(f"Database upgraded from {current_rev} to {target_rev}") + logger.info(f"Database upgraded from {current_rev} to {target_rev}") except Exception as e: if backup_path: # Restore the database from backup if upgrade fails shutil.copy(backup_path, db_path) os.remove(backup_path) - logging.exception("Error upgrading database: ") + logger.exception("Error upgrading database: ") raise e global Session diff --git a/comfy/cldm/dit_embedder.py b/comfy/cldm/dit_embedder.py index f9bf31012..e1f38b81c 100644 --- a/comfy/cldm/dit_embedder.py +++ b/comfy/cldm/dit_embedder.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn from torch import Tensor -from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch +from ..ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, \ + get_2d_sincos_pos_embed_torch class ControlNetEmbedder(nn.Module): diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d1b6af7c9..b289d7a17 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -11,7 +11,7 @@ import configargparse as argparse from . import __version__ from . import options from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \ - EnhancedConfigArgParser, PerformanceFeature, is_valid_directory + EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config # todo: move this DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" @@ -261,8 +261,8 @@ def _create_parser() -> EnhancedConfigArgParser: help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)", ) - parser.add_argument("--database-url", type=str, default=f"sqlite:///:memory:", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") - + default_db_url = db_config() + parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.") # now give plugins a chance to add configuration diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 1e9ed1b10..1c0148bd3 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import logging import os from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple @@ -20,6 +21,22 @@ class LatentPreviewMethod(enum.Enum): ConfigObserver = Callable[[str, Any], None] +def db_config() -> str: + from .vendor.appdirs import user_data_dir + + logger = logging.getLogger(__name__) + try: + data_dir = user_data_dir(appname="comfyui") + os.makedirs(data_dir, exist_ok=True) + db_path = os.path.join(data_dir, "comfy.db") + default_db_url = f"sqlite:///{db_path}" + except Exception as e: + # Fallback to an in-memory database if the user directory can't be accessed + logger.warning(f"Could not determine user data directory for database, falling back to in-memory: {e}") + default_db_url = "sqlite:///:memory:" + return default_db_url + + def is_valid_directory(path: str) -> str: """Validate if the given path is a directory, and check permissions.""" if not os.path.exists(path): @@ -261,7 +278,7 @@ class Configuration(dict): self.front_end_version: str = "comfyanonymous/ComfyUI@latest" self.front_end_root: Optional[str] = None self.comfy_api_base: str = "https://api.comfy.org" - self.database_url: str = "sqlite:///:memory:" + self.database_url: str = db_config() for key, value in kwargs.items(): self[key] = value diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index d6aaaf33f..98edaa0ad 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -14,11 +14,11 @@ from opentelemetry import context, propagate from opentelemetry.context import Context, attach, detach from opentelemetry.trace import Status, StatusCode +from ..cmd.main_pre import tracer from .client_types import V1QueuePromptResponse from ..api.components.schema.prompt import PromptDict from ..cli_args_types import Configuration from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error -from ..cmd.main_pre import tracer from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.make_mutable import make_mutable from ..distributed.executors import ContextVarExecutor @@ -97,7 +97,7 @@ async def __execute_prompt( else: prompt_executor.server = progress_handler - prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id}, + await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id}, execute_outputs=validation_tuple.good_output_node_ids) return prompt_executor.outputs_ui except Exception as exc_info: @@ -195,7 +195,7 @@ class Comfy: if isinstance(prompt, str): prompt = json.loads(prompt) if isinstance(prompt, dict): - from comfy.api.components.schema.prompt import Prompt + from ..api.components.schema.prompt import Prompt prompt = Prompt.validate(prompt) outputs = await self.queue_prompt(prompt) return V1QueuePromptResponse(urls=[], outputs=outputs) diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 21a8696f5..73f723921 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -19,14 +19,15 @@ from typing import List, Optional, Tuple, Literal import torch from opentelemetry.trace import get_current_span, StatusCode, Status +# order matters +from .main_pre import tracer + from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ DependencyAwareCache, \ BasicCache -# order matters from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.utils import CurrentNodeContext -from .main_pre import tracer from .. import interruption from .. import model_management from ..cli_args import args @@ -37,7 +38,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage from ..component_model.files import canonicalize_path from ..component_model.module_property import create_module_properties -from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus +from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict from ..execution_context import context_execute_node, context_execute_prompt from ..execution_ext import should_panic_on_exception from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode @@ -388,7 +389,7 @@ def format_value(x) -> FormattedValue: return str(x.__class__) -async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: +async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: """ :param server: @@ -402,8 +403,8 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca :param pending_subgraph_results: :return: """ - with context_execute_node(_node_id): - return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) + with context_execute_node(node_id): + return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: @@ -516,7 +517,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra unblock() asyncio.create_task(await_completion()) - return (ExecutionResult.PENDING, None, None) + return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None) if len(output_ui) > 0: caches.ui.set(unique_id, { "meta": { @@ -685,11 +686,6 @@ class PromptExecutor: if ex is not None and self.raise_exceptions: raise ex - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - asyncio_loop = asyncio.new_event_loop() - asyncio.set_event_loop(asyncio_loop) - asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) - async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): # torchao and potentially other optimization approaches break when the models are created in inference mode # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here @@ -1206,7 +1202,7 @@ class PromptQueue(AbstractPromptQueue): self.server.queue_updated() return copy.deepcopy(item_with_future.queue_tuple), task_id - def task_done(self, item_id: str, outputs: dict, + def task_done(self, item_id: str, outputs: HistoryResultDict, status: Optional[ExecutionStatus]): history_result = outputs with self.mutex: @@ -1215,9 +1211,9 @@ class PromptQueue(AbstractPromptQueue): if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - status_dict: Optional[dict] = None + status_dict = None if status is not None: - status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict()) + status_dict: Optional[ExecutionStatusAsDict] = status.as_dict() outputs_ = history_result["outputs"] # Remove sensitive data from extra_data before storing in history @@ -1225,11 +1221,13 @@ class PromptQueue(AbstractPromptQueue): if sensitive_val in prompt[3]: prompt[3].pop(sensitive_val) - self.history[prompt[1]] = { + history_entry: HistoryEntry = { "prompt": prompt, "outputs": copy.deepcopy(outputs_), - 'status': status_dict, } + if status_dict is not None: + history_entry["status"] = status_dict + self.history[prompt[1]] = history_entry self.history[prompt[1]].update(history_result) self.server.queue_updated() if queue_item.completed: diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 522ee1f52..bfd578cd8 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -10,20 +10,22 @@ import time from pathlib import Path from typing import Optional -from comfy.component_model.entrypoints_common import configure_application_paths, executor_from_args +# main_pre must be the earliest import +from .main_pre import args + from . import hook_breaker_ac10a0 from .extra_model_paths import load_extra_path_config -# main_pre must be the earliest import since it suppresses some spurious warnings -from .main_pre import args from .. import model_management from ..analytics.analytics import initialize_event_tracking from ..cmd import cuda_malloc from ..cmd import folder_paths from ..cmd import server as server_module from ..component_model.abstract_prompt_queue import AbstractPromptQueue +from ..component_model.entrypoints_common import configure_application_paths, executor_from_args from ..distributed.distributed_prompt_queue import DistributedPromptQueue from ..distributed.server_stub import ServerStub from ..nodes.package import import_all_nodes_in_workspace +from ..nodes_context import get_nodes logger = logging.getLogger(__name__) @@ -42,6 +44,10 @@ def cuda_malloc_warning(): def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer): + asyncio.run(_prompt_worker(q, server_instance)) + + +async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer): from ..cmd import execution from ..component_model import queue_types from .. import model_management @@ -68,7 +74,7 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS prompt_id = item[1] server_instance.last_prompt_id = prompt_id - e.execute(item[2], prompt_id, item[3], item[4]) + await e.execute_async(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, e.history_result, @@ -174,17 +180,16 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None): for config_path in itertools.chain(*args.extra_model_paths_config): load_extra_path_config(config_path) - # always create directories when started interactively - folder_paths.create_directories() if args.create_directories: + # then, import and exit import_all_nodes_in_workspace(raise_on_failure=False) folder_paths.create_directories() exit(0) - - setup_database() + elif args.quick_test_for_ci: + import_all_nodes_in_workspace(raise_on_failure=True) + exit(0) if args.windows_standalone_build: - folder_paths.create_directories() try: from . import new_updater new_updater.update_windows_updater() @@ -198,7 +203,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None): # at this stage, it's safe to import nodes hook_breaker_ac10a0.save_functions() - server.nodes = import_all_nodes_in_workspace() + server.nodes = get_nodes() hook_breaker_ac10a0.restore_functions() # as a side effect, this also populates the nodes for execution @@ -221,6 +226,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None): server.add_routes() cuda_malloc_warning() + setup_database() # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket worker_thread_server = server if not distributed else ServerStub() @@ -254,15 +260,8 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None): logger.debug(f"Setting input directory to: {input_dir}") folder_paths.set_input_directory(input_dir) - if args.quick_test_for_ci: - # for CI purposes, try importing all the nodes - import_all_nodes_in_workspace(raise_on_failure=True) - return - else: - # we no longer lazily load nodes. we'll do it now for the sake of creating directories - import_all_nodes_in_workspace(raise_on_failure=False) - # now that nodes are loaded, create more directories if appropriate - folder_paths.create_directories() + # now that nodes are loaded, create directories + folder_paths.create_directories() if len(args.workflows) > 0: configure_application_paths(args) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 7f094163a..0985be3ba 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -1169,7 +1169,7 @@ class PromptServer(ExecutorToClientProgress): await runner.setup() if 'tls_keyfile' in args or 'tls_certfile' in args: - raise ValueError("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md") + logger.warning("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md") def is_ipv4(address: str, *args): try: diff --git a/comfy/component_model/abstract_prompt_queue.py b/comfy/component_model/abstract_prompt_queue.py index b9f8f7bfe..74cde5fcf 100644 --- a/comfy/component_model/abstract_prompt_queue.py +++ b/comfy/component_model/abstract_prompt_queue.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing from abc import ABCMeta, abstractmethod +from .executor_types import HistoryResultDict from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems @@ -42,7 +43,7 @@ class AbstractPromptQueue(metaclass=ABCMeta): pass @abstractmethod - def task_done(self, item_id: str, outputs: dict, + def task_done(self, item_id: str, outputs: HistoryResultDict, status: typing.Optional[ExecutionStatus]): """ Signals to the user interface that the task with the specified id is completed diff --git a/comfy/component_model/queue_types.py b/comfy/component_model/queue_types.py index 9b1471b84..21be0e5f2 100644 --- a/comfy/component_model/queue_types.py +++ b/comfy/component_model/queue_types.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import copy from enum import Enum from typing import NamedTuple, Optional, List, Literal, Sequence from typing import Tuple @@ -24,6 +25,13 @@ class ExecutionStatus(NamedTuple): completed: bool messages: List[str] + def as_dict(self) -> ExecutionStatusAsDict: + return { + "status_str": self.status_str, + "completed": self.completed, + "messages": copy.copy(self.messages), + } + class ExecutionError(RuntimeError): def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args): diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py index 6947bb534..486657347 100644 --- a/comfy/distributed/distributed_prompt_queue.py +++ b/comfy/distributed/distributed_prompt_queue.py @@ -13,12 +13,12 @@ from aio_pika import connect_robust from aio_pika.abc import AbstractConnection, AbstractChannel from aio_pika.patterns import JsonRPC +from ..cmd.main_pre import tracer from .distributed_progress import ProgressHandlers from .distributed_types import RpcRequest, RpcReply from .history import History from .server_stub import ServerStub from ..auth.permissions import jwt_decode -from ..cmd.main_pre import tracer from ..cmd.server import PromptServer from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index c7ad179fb..d831ba734 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -10,12 +10,12 @@ from aio_pika.patterns import JsonRPC from aiohttp import web from aiormq import AMQPConnectionError +from ..cmd.main_pre import tracer from .executors import ContextVarExecutor from .distributed_progress import DistributedExecutorToClientProgress from .distributed_types import RpcRequest, RpcReply from .process_pool_executor import ProcessPoolExecutor from ..client.embedded_comfy_client import Comfy -from ..cmd.main_pre import tracer from ..component_model.queue_types import ExecutionStatus logger = logging.getLogger(__name__) diff --git a/comfy/entrypoints/worker.py b/comfy/entrypoints/worker.py index 9e180b251..a0bfc0653 100644 --- a/comfy/entrypoints/worker.py +++ b/comfy/entrypoints/worker.py @@ -2,7 +2,6 @@ import asyncio from ..cmd.main_pre import args from ..component_model.entrypoints_common import configure_application_paths, executor_from_args -from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor async def main(): diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 976f98c65..9a62551ca 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -1,7 +1,8 @@ import torch -from comfy.text_encoders.bert import BertAttention -import comfy.model_management -from comfy.ldm.modules.attention import optimized_attention_for_device + +from ..ldm.modules.attention import optimized_attention_for_device +from ..model_management import cast_to_device +from ..text_encoders.bert import BertAttention class Dino2AttentionOutput(torch.nn.Module): @@ -29,7 +30,7 @@ class LayerScale(torch.nn.Module): self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) def forward(self, x): - return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) + return x * cast_to_device(self.lambda1, x.device, x.dtype) class SwiGLUFFN(torch.nn.Module): @@ -117,7 +118,7 @@ class Dino2Embeddings(torch.nn.Module): x = self.patch_embeddings(pixel_values) # TODO: mask_token? x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) - x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + x = x + cast_to_device(self.position_embeddings, x.device, x.dtype) return x diff --git a/comfy/language/chat_templates.py b/comfy/language/chat_templates.py index 78838d064..c3e6834f8 100644 --- a/comfy/language/chat_templates.py +++ b/comfy/language/chat_templates.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -from importlib.abc import Traversable # pylint: disable=no-name-in-module from importlib.resources import files from pathlib import Path @@ -10,7 +9,7 @@ KNOWN_CHAT_TEMPLATES = {} def _update_known_chat_templates(): try: - _chat_templates: Traversable = files(__package__) / "chat_templates" + _chat_templates = files(__package__) / "chat_templates" _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()} KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates) except ImportError as exc: diff --git a/comfy/language/language_types.py b/comfy/language/language_types.py index 2a3601058..54b6a8484 100644 --- a/comfy/language/language_types.py +++ b/comfy/language/language_types.py @@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, T from transformers.utils import PaddingStrategy from typing_extensions import TypedDict, NotRequired -from comfy.component_model.tensor_types import RGBImageBatch +from ..component_model.tensor_types import RGBImageBatch class ProcessorResult(TypedDict): diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py index 5c4356a3f..e0c74033c 100644 --- a/comfy/ldm/cosmos/blocks.py +++ b/comfy/ldm/cosmos/blocks.py @@ -23,7 +23,7 @@ from einops import rearrange, repeat from einops.layers.torch import Rearrange from torch import nn -from comfy.ldm.modules.attention import optimized_attention +from ..modules.attention import optimized_attention def get_normalization(name: str, channels: int, weight_args={}, operations=None): diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 316117f77..b52422780 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -11,7 +11,7 @@ import math from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis from torchvision import transforms -from comfy.ldm.modules.attention import optimized_attention +from ..modules.attention import optimized_attention def apply_rotary_pos_emb( t: torch.Tensor, diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py index 39d049309..e75d321e8 100644 --- a/comfy/ldm/hunyuan3d/model.py +++ b/comfy/ldm/hunyuan3d/model.py @@ -1,12 +1,6 @@ import torch from torch import nn -from comfy.ldm.flux.layers import ( - DoubleStreamBlock, - LastLayer, - MLPEmbedder, - SingleStreamBlock, - timestep_embedding, -) +from ..flux.layers import DoubleStreamBlock, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding class Hunyuan3Dv2(nn.Module): diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index a8ebc5ec6..c805ac849 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -5,10 +5,10 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy.ldm.modules.diffusionmodules.model import vae_attention +from ..modules.diffusionmodules.model import vae_attention -import comfy.ops -ops = comfy.ops.disable_weight_init +from ...ops import disable_weight_init +ops = disable_weight_init CACHE_T = 2 diff --git a/comfy/model_management.py b/comfy/model_management.py index 2d379cc02..29f368b7d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,9 +31,9 @@ import psutil import torch from opentelemetry.trace import get_current_span +from .cmd.main_pre import tracer from . import interruption from .cli_args import args, PerformanceFeature -from .cmd.main_pre import tracer from .component_model.deprecation import _deprecate_method from .model_management_types import ModelManageable diff --git a/comfy/nodes/package.py b/comfy/nodes/package.py index 89cb946aa..6c48d54a7 100644 --- a/comfy/nodes/package.py +++ b/comfy/nodes/package.py @@ -11,12 +11,13 @@ from importlib.metadata import entry_points from opentelemetry.trace import Span, Status, StatusCode -from .package_typing import ExportedNodes from ..cmd.main_pre import tracer +from .package_typing import ExportedNodes from ..component_model.files import get_package_as_path -_comfy_nodes: ExportedNodes = ExportedNodes() +_nodes_available_at_startup: ExportedNodes = ExportedNodes() +logger = logging.getLogger(__name__) def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType): node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None) @@ -55,7 +56,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes)) except Exception as exc: module_decl = None - logging.error(f"{full_name} import failed", exc_info=exc) + logger.error(f"{full_name} import failed", exc_info=exc) span.set_status(Status(StatusCode.ERROR)) span.record_exception(exc) exceptions.append(exc) @@ -84,7 +85,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, potential_path_error: AttributeError = x if potential_path_error.name == '__path__': continue - logging.error(f"{full_name} import failed", exc_info=x) + logger.error(f"{full_name} import failed", exc_info=x) success = False exceptions.append(x) span.set_status(Status(StatusCode.ERROR)) @@ -93,7 +94,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings): for (duration, module_name, success, new_nodes) in sorted(timings): - logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)") + logger.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)") if raise_on_failure and len(exceptions) > 0: try: raise ExceptionGroup("Node import failed", exceptions) # pylint: disable=using-exception-groups-in-unsupported-version @@ -105,12 +106,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType, @tracer.start_as_current_span("Import All Nodes In Workspace") def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes: # now actually import the nodes, to improve control of node loading order - from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used from ..cli_args import args - from . import base_nodes - from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes + # only load these nodes once - if len(_comfy_nodes) == 0: + if len(_nodes_available_at_startup) == 0: + + # import base_nodes first + from . import base_nodes + from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used + from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes + base_and_extra = reduce(lambda x, y: x.update(y), map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [ # this is the list of default nodes to import @@ -121,9 +126,9 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa custom_nodes_mappings = ExportedNodes() if args.disable_all_custom_nodes: - logging.info("Loading custom nodes was disabled, only base and extra nodes were loaded") - _comfy_nodes.update(base_and_extra) - return _comfy_nodes + logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded") + _nodes_available_at_startup.update(base_and_extra) + return _nodes_available_at_startup # load from entrypoints for entry_point in entry_points().select(group='comfyui.custom_nodes'): @@ -131,7 +136,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa try: module = entry_point.load() except ModuleNotFoundError as module_not_found_error: - logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) + logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error) continue # Ensure that what we've loaded is indeed a module @@ -146,5 +151,5 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa # don't allow custom nodes to overwrite base nodes custom_nodes_mappings -= base_and_extra - _comfy_nodes.update(base_and_extra + custom_nodes_mappings) - return _comfy_nodes + _nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings) + return _nodes_available_at_startup diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index 2639b3945..911db4d69 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -7,8 +7,6 @@ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, T from typing_extensions import TypedDict, NotRequired -from comfy.comfy_types import FileLocator - T = TypeVar('T') diff --git a/pyproject.toml b/pyproject.toml index c6d17e4a2..745415567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ dev = [ "freezegun", "coverage", "pylint", + "astroid", ] [project.optional-dependencies] diff --git a/tests/absolute_import_checker.py b/tests/absolute_import_checker.py index 13bb644e6..fb1f2afaf 100644 --- a/tests/absolute_import_checker.py +++ b/tests/absolute_import_checker.py @@ -24,19 +24,47 @@ class AbsoluteImportChecker(BaseChecker): super().__init__(linter) def visit_importfrom(self, node: nodes.ImportFrom) -> None: - current_file = node.root().file - if current_file is None: + """ + Check for absolute imports from the same top-level package. + + This method is called for every `from ... import ...` statement. + It checks if a module within 'comfy' or 'comfy_extras' packages + is using an absolute import from its own package, which should + be a relative import instead. + + For example, inside `comfy/nodes/logic.py`, an import like + `from comfy.utils import some_function` will be flagged. + The preferred way would be `from ..utils import some_function`. + """ + # An import is relative if its level is greater than 0. + # e.g., from . import foo (level=1), from .. import bar (level=2) + # We only want to check absolute imports, so we skip relative ones. + if node.level and node.level > 0: return - package_path = os.path.dirname(current_file) - package_name = os.path.basename(package_path) + # Get the fully qualified name of the module being linted. + # For a file at '.../comfy/nodes/common.py', this will be 'comfy.nodes.common'. + module_qname = node.root().qname() - if node.modname.startswith(package_name) and package_name in ['comfy', 'comfy_extras']: - import_parts = node.modname.split('.') + # `node.modname` is the module name in the `from` statement. + # For `from comfy.utils import x`, `modname` is `comfy.utils`. + imported_modname = node.modname + if not imported_modname: + return - if import_parts[0] == package_name: - self.add_message('absolute-import-used', node=node, args=(node.modname,)) + # We are only interested in modules within 'comfy' or 'comfy_extras'. + # We determine this by looking at the first part of the qualified name. + current_top_package = module_qname.split('.')[0] + if current_top_package not in ['comfy', 'comfy_extras']: + return + + imported_top_package = imported_modname.split('.')[0] + + # If the top-level package of the imported module is the same as the + # current module's top-level package, it's an internal absolute import. + if imported_top_package == current_top_package: + self.add_message('absolute-import-used', node=node, args=(imported_modname,)) def register(linter: "PyLinter") -> None: - linter.register_checker(AbsoluteImportChecker(linter)) + linter.register_checker(AbsoluteImportChecker(linter)) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index fdcf57430..c00686f25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,50 +1,25 @@ -import sys -import time - import logging import multiprocessing import os import pathlib +import subprocess +import sys +import time +import urllib +from typing import Tuple, List + import pytest import requests -import socket -import subprocess -import urllib -from testcontainers.rabbitmq import RabbitMqContainer -from typing import Tuple, List from comfy.cli_args_types import Configuration logging.getLogger("pika").setLevel(logging.CRITICAL + 1) logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1) -logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING) -logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING) # fixes issues with running the testcontainers rabbitmqcontainer on Windows os.environ["TC_HOST"] = "localhost" -def get_lan_ip(): - """ - Finds the host's IP address on the LAN it's connected to. - - Returns: - str: The IP address of the host on the LAN. - """ - # Create a dummy socket - s = None - try: - # Connect to a dummy address (Here, Google's public DNS server) - # The actual connection is not made, but this allows finding out the LAN IP - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - s.connect(("8.8.8.8", 80)) - ip = s.getsockname()[0] - finally: - if s is not None: - s.close() - return ip - - def run_server(server_arguments: Configuration): from comfy.cmd.main import main from comfy.cli_args import args @@ -95,6 +70,11 @@ def has_gpu() -> bool: @pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"]) def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1): from huggingface_hub import hf_hub_download + from testcontainers.rabbitmq import RabbitMqContainer + + logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING) + logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING) + hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors") hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors") @@ -108,8 +88,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers frontend_command = [ "comfyui", - "--listen=0.0.0.0", - "--port=9001", + "--listen=127.0.0.1", + "--port=19001", "--cpu", "--distributed-queue-frontend", f"-w={str(tmp_path)}", @@ -122,7 +102,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers for i in range(num_workers): backend_command = [ "comfyui-worker", - f"--port={9002 + i}", + f"--port={19002 + i}", f"-w={str(tmp_path)}", f"--distributed-queue-connection-uri={connection_uri}", f"--executor-factory={executor_factory}" @@ -130,7 +110,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr)) try: - server_address = f"http://{get_lan_ip()}:9001" + server_address = f"http://127.0.0.1:19001" start_time = time.time() connected = False while time.time() - start_time < 60: diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index cc0f40af7..b3055f129 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -15,7 +15,7 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex DependencyCycleError from comfy.distributed.server_stub import ServerStub from comfy.execution_context import context_add_custom_nodes -from comfy.graph_utils import GraphBuilder, Node +from comfy_execution.graph_utils import GraphBuilder, Node from comfy.nodes.package_typing import ExportedNodes current_test_name = ContextVar('current_test_name', default=None) diff --git a/tests/inference/testing_pack/flow_control.py b/tests/inference/testing_pack/flow_control.py index 1ef1cf803..f5d495f38 100644 --- a/tests/inference/testing_pack/flow_control.py +++ b/tests/inference/testing_pack/flow_control.py @@ -1,4 +1,4 @@ -from comfy.graph_utils import GraphBuilder, is_link +from comfy_execution.graph_utils import GraphBuilder, is_link from comfy.graph import ExecutionBlocker from .tools import VariantSupport diff --git a/tests/inference/testing_pack/specific_tests.py b/tests/inference/testing_pack/specific_tests.py index 9cb20d081..1f626cd15 100644 --- a/tests/inference/testing_pack/specific_tests.py +++ b/tests/inference/testing_pack/specific_tests.py @@ -3,7 +3,7 @@ import time import asyncio from comfy.utils import ProgressBar from .tools import VariantSupport -from comfy.graph_utils import GraphBuilder +from comfy_execution.graph_utils import GraphBuilder from comfy.comfy_types.node_typing import ComfyNodeABC from comfy.comfy_types import IO diff --git a/tests/inference/testing_pack/util.py b/tests/inference/testing_pack/util.py index 1a3fca0c9..17741c5f1 100644 --- a/tests/inference/testing_pack/util.py +++ b/tests/inference/testing_pack/util.py @@ -1,4 +1,4 @@ -from comfy.graph_utils import GraphBuilder +from comfy_execution.graph_utils import GraphBuilder from .tools import VariantSupport @VariantSupport() diff --git a/tests/main_pre_import_checker.py b/tests/main_pre_import_checker.py new file mode 100644 index 000000000..645c08588 --- /dev/null +++ b/tests/main_pre_import_checker.py @@ -0,0 +1,129 @@ +"""Pylint checker for ensuring main_pre is imported first.""" + +from typing import TYPE_CHECKING, Optional, Union + +import astroid +from astroid import nodes +from pylint.checkers import BaseChecker + +if TYPE_CHECKING: + from pylint.lint import PyLinter + + +class MainPreImportOrderChecker(BaseChecker): + """ + Ensures that imports from 'comfy.cmd.main_pre' or similar setup modules + occur before any other relative imports or imports from the 'comfy' package. + + This is important for code that relies on setup being performed by 'main_pre' + before other modules from the same package are imported. + """ + + name = 'main-pre-import-order' + msgs = { + 'W0002': ( + 'Setup import %s must be placed before other package imports like %s.', + 'main-pre-import-not-first', + "To ensure necessary setup is performed, 'comfy.cmd.main_pre' or similar " + "setup imports must precede other relative imports or imports from the " + "'comfy' family of packages." + ), + } + + def _is_main_pre_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool: + """Checks if an import statement is for 'comfy.cmd.main_pre'.""" + if isinstance(stmt, nodes.Import): + for name, _ in stmt.names: + if name == 'comfy.cmd.main_pre' or name.startswith('comfy.cmd.main_pre.'): + return True + return False + + if isinstance(stmt, nodes.ImportFrom): + qname: Optional[str] = None + if stmt.level == 0: + qname = stmt.modname + else: + try: + # Attempt to resolve the relative import to a fully qualified name + imported_module = stmt.do_import_module() + qname = imported_module.qname() + except astroid.AstroidError: + # Fallback for unresolved relative imports, check the literal module name + if stmt.modname and stmt.modname.endswith('.main_pre'): + return True + # Heuristic for `from ..cmd import main_pre` in `comfy/entrypoints/*` + if stmt.modname == 'cmd' and stmt.root().qname().startswith('comfy'): + for name, _ in stmt.names: + if name == 'main_pre': + return True + + if not qname: + return False + + # from comfy.cmd import main_pre + if qname == 'comfy.cmd': + for name, _ in stmt.names: + if name == 'main_pre': + return True + + # from comfy.cmd.main_pre import ... OR from a.b.c.main_pre import ... + if qname == 'comfy.cmd.main_pre' or qname.endswith('.main_pre'): + return True + + return False + + def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool: + """ + Checks if an import is a relative import or an import from + the 'comfy' package family, and is not a 'main_pre' import. + """ + if self._is_main_pre_import(stmt): + return False + + if isinstance(stmt, nodes.ImportFrom): + if stmt.level and stmt.level > 0: # Any relative import + return True + if stmt.modname and stmt.modname.startswith('comfy'): + return True + + if isinstance(stmt, nodes.Import): + for name, _ in stmt.names: + if name.startswith('comfy'): + return True + + return False + + def visit_module(self, node: nodes.Module) -> None: + """Checks the order of imports within a module.""" + imports = [ + stmt for stmt in node.body + if isinstance(stmt, (nodes.Import, nodes.ImportFrom)) + ] + + main_pre_import_node = None + for stmt in imports: + if self._is_main_pre_import(stmt): + main_pre_import_node = stmt + break + + # If there's no main_pre import, there's nothing to check. + if not main_pre_import_node: + return + + for stmt in imports: + # We only care about imports that appear before the main_pre import + if stmt.fromlineno >= main_pre_import_node.fromlineno: + break + + if self._is_other_relevant_import(stmt): + self.add_message( + 'main-pre-import-not-first', + node=main_pre_import_node, + args=(main_pre_import_node.as_string(), stmt.as_string()) + ) + return # Report once per file and exit to avoid spam + + +def register(linter: "PyLinter") -> None: + """This function is required for a Pylint plugin.""" + linter.register_checker(MainPreImportOrderChecker(linter)) diff --git a/tests/unit/prompt_server_test/user_manager_test.py b/tests/unit/prompt_server_test/user_manager_test.py new file mode 100644 index 000000000..2df82bff3 --- /dev/null +++ b/tests/unit/prompt_server_test/user_manager_test.py @@ -0,0 +1,292 @@ +import os +from unittest.mock import patch + +import pytest +from aiohttp import web + +from comfy.app.user_manager import UserManager + +pytestmark = ( + pytest.mark.asyncio +) # This applies the asyncio mark to all test functions in the module + + +@pytest.fixture +def user_manager(tmp_path): + um = UserManager() + um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join( + tmp_path, file + ) if file else tmp_path + return um + + +@pytest.fixture +def app(user_manager): + app = web.Application() + routes = web.RouteTableDef() + user_manager.add_routes(routes) + app.add_routes(routes) + return app + + +async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir") + assert resp.status == 404 + + +async def test_listuserdata_with_files(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir") + assert resp.status == 200 + assert await resp.json() == ["file1.txt"] + + +async def test_listuserdata_recursive(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true") + assert resp.status == 200 + assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"} + + +async def test_listuserdata_full_info(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&full_info=true") + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert result[0]["path"] == "file1.txt" + assert "size" in result[0] + assert "modified" in result[0] + + +async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true") + assert resp.status == 200 + assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]] + + +async def test_listuserdata_invalid_directory(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=") + assert resp.status == 400 + + +async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path): + os_sep = "\\" + with patch("os.sep", os_sep): + with patch("os.path.sep", os_sep): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true") + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert "/" in result[0] # Ensure forward slash is used + assert "\\" not in result[0] # Ensure backslash is not present + assert result[0] == "subdir/file1.txt" + + # Test with full_info + resp = await client.get( + "/userdata?dir=test_dir&recurse=true&full_info=true" + ) + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert "/" in result[0]["path"] # Ensure forward slash is used + assert "\\" not in result[0]["path"] # Ensure backslash is not present + assert result[0]["path"] == "subdir/file1.txt" + + +async def test_post_userdata_new_file(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + content = b"test content" + resp = await client.post("/userdata/test.txt", data=content) + + assert resp.status == 200 + assert await resp.text() == '"test.txt"' + + # Verify file was created with correct content + with open(tmp_path / "test.txt", "rb") as f: + assert f.read() == content + + +async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "test.txt", "w") as f: + f.write("initial content") + + client = await aiohttp_client(app) + new_content = b"updated content" + resp = await client.post("/userdata/test.txt", data=new_content) + + assert resp.status == 200 + assert await resp.text() == '"test.txt"' + + # Verify file was overwritten + with open(tmp_path / "test.txt", "rb") as f: + assert f.read() == new_content + + +async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "test.txt", "w") as f: + f.write("initial content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content") + + assert resp.status == 409 + + # Verify original content unchanged + with open(tmp_path / "test.txt", "r") as f: + assert f.read() == "initial content" + + +async def test_post_userdata_full_info(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + content = b"test content" + resp = await client.post("/userdata/test.txt?full_info=true", data=content) + + assert resp.status == 200 + result = await resp.json() + assert result["path"] == "test.txt" + assert result["size"] == len(content) + assert "modified" in result + + +async def test_move_userdata(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "source.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/source.txt/move/dest.txt") + + assert resp.status == 200 + assert await resp.text() == '"dest.txt"' + + # Verify file was moved + assert not os.path.exists(tmp_path / "source.txt") + with open(tmp_path / "dest.txt", "r") as f: + assert f.read() == "test content" + + +async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path): + # Create source and destination files + with open(tmp_path / "source.txt", "w") as f: + f.write("source content") + with open(tmp_path / "dest.txt", "w") as f: + f.write("destination content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false") + + assert resp.status == 409 + + # Verify files remain unchanged + with open(tmp_path / "source.txt", "r") as f: + assert f.read() == "source content" + with open(tmp_path / "dest.txt", "r") as f: + assert f.read() == "destination content" + + +async def test_move_userdata_full_info(aiohttp_client, app, tmp_path): + # Create initial file + with open(tmp_path / "source.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true") + + assert resp.status == 200 + result = await resp.json() + assert result["path"] == "dest.txt" + assert result["size"] == len("test content") + assert "modified" in result + + # Verify file was moved + assert not os.path.exists(tmp_path / "source.txt") + with open(tmp_path / "dest.txt", "r") as f: + assert f.read() == "test content" + + +async def test_listuserdata_v2_empty_root(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata") + assert resp.status == 200 + assert await resp.json() == [] + + +async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=does_not_exist") + assert resp.status == 404 + + +async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + (tmp_path / "test_dir" / "file1.txt").write_text("content") + (tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content") + + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=test_dir") + assert resp.status == 200 + data = await resp.json() + file_paths = {item["path"] for item in data if item["type"] == "file"} + assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"} + + +async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch): + # Force backslash as os separator + monkeypatch.setattr(os, 'sep', '\\') + monkeypatch.setattr(os.path, 'sep', '\\') + os.makedirs(tmp_path / "test_dir" / "subdir") + (tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x") + + client = await aiohttp_client(app) + resp = await client.get("/v2/userdata?path=test_dir") + assert resp.status == 200 + data = await resp.json() + for item in data: + assert "/" in item["path"] + assert "\\" not in item["path"] + + +async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path): + # Create a directory with a space in its name and a file inside + os.makedirs(tmp_path / "my dir") + (tmp_path / "my dir" / "file.txt").write_text("content") + + client = await aiohttp_client(app) + # Use URL-encoded space in path parameter + resp = await client.get("/v2/userdata?path=my%20dir&recurse=false") + assert resp.status == 200 + data = await resp.json() + assert len(data) == 1 + entry = data[0] + assert entry["name"] == "file.txt" + # Ensure the path is correctly decoded and uses forward slash + assert entry["path"] == "my dir/file.txt" diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index 1fee71e17..90ccbf99e 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -122,7 +122,7 @@ def test_bool_request_parameter(): assert v == True -def test_string_enum_request_parameter(): +async def test_string_enum_request_parameter(): nt = StringEnumRequestParameter.INPUT_TYPES() assert nt is not None n = StringEnumRequestParameter() @@ -155,8 +155,9 @@ def test_string_enum_request_parameter(): } from comfy.cmd.execution import validate_inputs validated: dict[str, ValidateInputsTuple] = {} - validated["1"] = validate_inputs(prompt, "1", validated) - validated["2"] = validate_inputs(prompt, "2", validated) + prompt_id = str(uuid.uuid4()) + validated["1"] = await validate_inputs(prompt_id, prompt, "1", validated) + validated["2"] = await validate_inputs(prompt_id, prompt, "2", validated) assert validated["2"].valid