mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
packaging fixes
- enable user db - fix main_pre order everywhere - fix absolute to relative imports everywhere - async better supported
This commit is contained in:
parent
c086c5e005
commit
96b4e04315
@ -82,7 +82,7 @@ limit-inference-results=100
|
||||
|
||||
# List of plugins (as comma separated values of python module names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=tests.absolute_import_checker
|
||||
load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=yes
|
||||
@ -678,7 +678,7 @@ disable=raw-checker-failed,
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
enable=absolute-import-used
|
||||
enable=
|
||||
|
||||
[METHOD_ARGS]
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ script_location = alembic_db
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
# prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
@ -63,7 +63,7 @@ version_path_separator = os
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = sqlite:///user/comfyui.db
|
||||
sqlalchemy.url = sqlite:///./comfyui.db
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
0
comfy/alembic_db/__init__.py
Normal file
0
comfy/alembic_db/__init__.py
Normal file
@ -1,3 +1,4 @@
|
||||
# pylint: disable=no-member
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
@ -7,8 +8,7 @@ from alembic import context
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
from comfy.app.database.models import Base
|
||||
from ..app.database.models import Base
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from importlib.resources import files
|
||||
|
||||
from ...cli_args import args
|
||||
from ...component_model.files import get_package_as_path
|
||||
|
||||
Session = None
|
||||
|
||||
@ -15,6 +17,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
@ -32,9 +35,8 @@ def can_create_session():
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
config_path = str(files("comfy") / "alembic.ini")
|
||||
scripts_path = get_package_as_path("comfy.alembic_db")
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
@ -53,7 +55,7 @@ def get_db_path():
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
logger.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
@ -70,7 +72,7 @@ def init_db():
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
logger.debug("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
@ -81,13 +83,13 @@ def init_db():
|
||||
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
logger.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
logger.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
|
||||
@ -5,7 +5,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
|
||||
from ..ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, \
|
||||
get_2d_sincos_pos_embed_torch
|
||||
|
||||
|
||||
class ControlNetEmbedder(nn.Module):
|
||||
|
||||
@ -11,7 +11,7 @@ import configargparse as argparse
|
||||
from . import __version__
|
||||
from . import options
|
||||
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
|
||||
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory
|
||||
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config
|
||||
|
||||
# todo: move this
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
@ -261,8 +261,8 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
|
||||
)
|
||||
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///:memory:", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
|
||||
default_db_url = db_config()
|
||||
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
|
||||
|
||||
# now give plugins a chance to add configuration
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
|
||||
|
||||
@ -20,6 +21,22 @@ class LatentPreviewMethod(enum.Enum):
|
||||
ConfigObserver = Callable[[str, Any], None]
|
||||
|
||||
|
||||
def db_config() -> str:
|
||||
from .vendor.appdirs import user_data_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
try:
|
||||
data_dir = user_data_dir(appname="comfyui")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
db_path = os.path.join(data_dir, "comfy.db")
|
||||
default_db_url = f"sqlite:///{db_path}"
|
||||
except Exception as e:
|
||||
# Fallback to an in-memory database if the user directory can't be accessed
|
||||
logger.warning(f"Could not determine user data directory for database, falling back to in-memory: {e}")
|
||||
default_db_url = "sqlite:///:memory:"
|
||||
return default_db_url
|
||||
|
||||
|
||||
def is_valid_directory(path: str) -> str:
|
||||
"""Validate if the given path is a directory, and check permissions."""
|
||||
if not os.path.exists(path):
|
||||
@ -261,7 +278,7 @@ class Configuration(dict):
|
||||
self.front_end_version: str = "comfyanonymous/ComfyUI@latest"
|
||||
self.front_end_root: Optional[str] = None
|
||||
self.comfy_api_base: str = "https://api.comfy.org"
|
||||
self.database_url: str = "sqlite:///:memory:"
|
||||
self.database_url: str = db_config()
|
||||
|
||||
for key, value in kwargs.items():
|
||||
self[key] = value
|
||||
|
||||
@ -14,11 +14,11 @@ from opentelemetry import context, propagate
|
||||
from opentelemetry.context import Context, attach, detach
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
|
||||
from ..cmd.main_pre import tracer
|
||||
from .client_types import V1QueuePromptResponse
|
||||
from ..api.components.schema.prompt import PromptDict
|
||||
from ..cli_args_types import Configuration
|
||||
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.make_mutable import make_mutable
|
||||
from ..distributed.executors import ContextVarExecutor
|
||||
@ -97,7 +97,7 @@ async def __execute_prompt(
|
||||
else:
|
||||
prompt_executor.server = progress_handler
|
||||
|
||||
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
|
||||
execute_outputs=validation_tuple.good_output_node_ids)
|
||||
return prompt_executor.outputs_ui
|
||||
except Exception as exc_info:
|
||||
@ -195,7 +195,7 @@ class Comfy:
|
||||
if isinstance(prompt, str):
|
||||
prompt = json.loads(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
from comfy.api.components.schema.prompt import Prompt
|
||||
from ..api.components.schema.prompt import Prompt
|
||||
prompt = Prompt.validate(prompt)
|
||||
outputs = await self.queue_prompt(prompt)
|
||||
return V1QueuePromptResponse(urls=[], outputs=outputs)
|
||||
|
||||
@ -19,14 +19,15 @@ from typing import List, Optional, Tuple, Literal
|
||||
import torch
|
||||
from opentelemetry.trace import get_current_span, StatusCode, Status
|
||||
|
||||
# order matters
|
||||
from .main_pre import tracer
|
||||
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
|
||||
DependencyAwareCache, \
|
||||
BasicCache
|
||||
# order matters
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from .main_pre import tracer
|
||||
from .. import interruption
|
||||
from .. import model_management
|
||||
from ..cli_args import args
|
||||
@ -37,7 +38,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
|
||||
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
|
||||
from ..component_model.files import canonicalize_path
|
||||
from ..component_model.module_property import create_module_properties
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
|
||||
from ..execution_context import context_execute_node, context_execute_prompt
|
||||
from ..execution_ext import should_panic_on_exception
|
||||
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||
@ -388,7 +389,7 @@ def format_value(x) -> FormattedValue:
|
||||
return str(x.__class__)
|
||||
|
||||
|
||||
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
||||
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
||||
"""
|
||||
|
||||
:param server:
|
||||
@ -402,8 +403,8 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
|
||||
:param pending_subgraph_results:
|
||||
:return:
|
||||
"""
|
||||
with context_execute_node(_node_id):
|
||||
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
with context_execute_node(node_id):
|
||||
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
|
||||
|
||||
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
|
||||
@ -516,7 +517,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
|
||||
unblock()
|
||||
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
"meta": {
|
||||
@ -685,11 +686,6 @@ class PromptExecutor:
|
||||
if ex is not None and self.raise_exceptions:
|
||||
raise ex
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
# torchao and potentially other optimization approaches break when the models are created in inference mode
|
||||
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
|
||||
@ -1206,7 +1202,7 @@ class PromptQueue(AbstractPromptQueue):
|
||||
self.server.queue_updated()
|
||||
return copy.deepcopy(item_with_future.queue_tuple), task_id
|
||||
|
||||
def task_done(self, item_id: str, outputs: dict,
|
||||
def task_done(self, item_id: str, outputs: HistoryResultDict,
|
||||
status: Optional[ExecutionStatus]):
|
||||
history_result = outputs
|
||||
with self.mutex:
|
||||
@ -1215,9 +1211,9 @@ class PromptQueue(AbstractPromptQueue):
|
||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||
self.history.pop(next(iter(self.history)))
|
||||
|
||||
status_dict: Optional[dict] = None
|
||||
status_dict = None
|
||||
if status is not None:
|
||||
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
|
||||
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
|
||||
|
||||
outputs_ = history_result["outputs"]
|
||||
# Remove sensitive data from extra_data before storing in history
|
||||
@ -1225,11 +1221,13 @@ class PromptQueue(AbstractPromptQueue):
|
||||
if sensitive_val in prompt[3]:
|
||||
prompt[3].pop(sensitive_val)
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
history_entry: HistoryEntry = {
|
||||
"prompt": prompt,
|
||||
"outputs": copy.deepcopy(outputs_),
|
||||
'status': status_dict,
|
||||
}
|
||||
if status_dict is not None:
|
||||
history_entry["status"] = status_dict
|
||||
self.history[prompt[1]] = history_entry
|
||||
self.history[prompt[1]].update(history_result)
|
||||
self.server.queue_updated()
|
||||
if queue_item.completed:
|
||||
|
||||
@ -10,20 +10,22 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from comfy.component_model.entrypoints_common import configure_application_paths, executor_from_args
|
||||
# main_pre must be the earliest import
|
||||
from .main_pre import args
|
||||
|
||||
from . import hook_breaker_ac10a0
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
# main_pre must be the earliest import since it suppresses some spurious warnings
|
||||
from .main_pre import args
|
||||
from .. import model_management
|
||||
from ..analytics.analytics import initialize_event_tracking
|
||||
from ..cmd import cuda_malloc
|
||||
from ..cmd import folder_paths
|
||||
from ..cmd import server as server_module
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
|
||||
from ..distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||
from ..distributed.server_stub import ServerStub
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
from ..nodes_context import get_nodes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -42,6 +44,10 @@ def cuda_malloc_warning():
|
||||
|
||||
|
||||
def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
|
||||
asyncio.run(_prompt_worker(q, server_instance))
|
||||
|
||||
|
||||
async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
|
||||
from ..cmd import execution
|
||||
from ..component_model import queue_types
|
||||
from .. import model_management
|
||||
@ -68,7 +74,7 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS
|
||||
prompt_id = item[1]
|
||||
server_instance.last_prompt_id = prompt_id
|
||||
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
await e.execute_async(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
@ -174,17 +180,16 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
for config_path in itertools.chain(*args.extra_model_paths_config):
|
||||
load_extra_path_config(config_path)
|
||||
|
||||
# always create directories when started interactively
|
||||
folder_paths.create_directories()
|
||||
if args.create_directories:
|
||||
# then, import and exit
|
||||
import_all_nodes_in_workspace(raise_on_failure=False)
|
||||
folder_paths.create_directories()
|
||||
exit(0)
|
||||
|
||||
setup_database()
|
||||
elif args.quick_test_for_ci:
|
||||
import_all_nodes_in_workspace(raise_on_failure=True)
|
||||
exit(0)
|
||||
|
||||
if args.windows_standalone_build:
|
||||
folder_paths.create_directories()
|
||||
try:
|
||||
from . import new_updater
|
||||
new_updater.update_windows_updater()
|
||||
@ -198,7 +203,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
|
||||
# at this stage, it's safe to import nodes
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
server.nodes = import_all_nodes_in_workspace()
|
||||
server.nodes = get_nodes()
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
# as a side effect, this also populates the nodes for execution
|
||||
|
||||
@ -221,6 +226,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
|
||||
server.add_routes()
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
# in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
|
||||
worker_thread_server = server if not distributed else ServerStub()
|
||||
@ -254,15 +260,8 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
|
||||
logger.debug(f"Setting input directory to: {input_dir}")
|
||||
folder_paths.set_input_directory(input_dir)
|
||||
|
||||
if args.quick_test_for_ci:
|
||||
# for CI purposes, try importing all the nodes
|
||||
import_all_nodes_in_workspace(raise_on_failure=True)
|
||||
return
|
||||
else:
|
||||
# we no longer lazily load nodes. we'll do it now for the sake of creating directories
|
||||
import_all_nodes_in_workspace(raise_on_failure=False)
|
||||
# now that nodes are loaded, create more directories if appropriate
|
||||
folder_paths.create_directories()
|
||||
# now that nodes are loaded, create directories
|
||||
folder_paths.create_directories()
|
||||
|
||||
if len(args.workflows) > 0:
|
||||
configure_application_paths(args)
|
||||
|
||||
@ -1169,7 +1169,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
await runner.setup()
|
||||
|
||||
if 'tls_keyfile' in args or 'tls_certfile' in args:
|
||||
raise ValueError("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
|
||||
logger.warning("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
|
||||
|
||||
def is_ipv4(address: str, *args):
|
||||
try:
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from .executor_types import HistoryResultDict
|
||||
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems
|
||||
|
||||
|
||||
@ -42,7 +43,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def task_done(self, item_id: str, outputs: dict,
|
||||
def task_done(self, item_id: str, outputs: HistoryResultDict,
|
||||
status: typing.Optional[ExecutionStatus]):
|
||||
"""
|
||||
Signals to the user interface that the task with the specified id is completed
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import NamedTuple, Optional, List, Literal, Sequence
|
||||
from typing import Tuple
|
||||
@ -24,6 +25,13 @@ class ExecutionStatus(NamedTuple):
|
||||
completed: bool
|
||||
messages: List[str]
|
||||
|
||||
def as_dict(self) -> ExecutionStatusAsDict:
|
||||
return {
|
||||
"status_str": self.status_str,
|
||||
"completed": self.completed,
|
||||
"messages": copy.copy(self.messages),
|
||||
}
|
||||
|
||||
|
||||
class ExecutionError(RuntimeError):
|
||||
def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args):
|
||||
|
||||
@ -13,12 +13,12 @@ from aio_pika import connect_robust
|
||||
from aio_pika.abc import AbstractConnection, AbstractChannel
|
||||
from aio_pika.patterns import JsonRPC
|
||||
|
||||
from ..cmd.main_pre import tracer
|
||||
from .distributed_progress import ProgressHandlers
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from .history import History
|
||||
from .server_stub import ServerStub
|
||||
from ..auth.permissions import jwt_decode
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..cmd.server import PromptServer
|
||||
from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict
|
||||
|
||||
@ -10,12 +10,12 @@ from aio_pika.patterns import JsonRPC
|
||||
from aiohttp import web
|
||||
from aiormq import AMQPConnectionError
|
||||
|
||||
from ..cmd.main_pre import tracer
|
||||
from .executors import ContextVarExecutor
|
||||
from .distributed_progress import DistributedExecutorToClientProgress
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from .process_pool_executor import ProcessPoolExecutor
|
||||
from ..client.embedded_comfy_client import Comfy
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.queue_types import ExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -2,7 +2,6 @@ import asyncio
|
||||
|
||||
from ..cmd.main_pre import args
|
||||
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
|
||||
from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
import torch
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
from ..ldm.modules.attention import optimized_attention_for_device
|
||||
from ..model_management import cast_to_device
|
||||
from ..text_encoders.bert import BertAttention
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
@ -29,7 +30,7 @@ class LayerScale(torch.nn.Module):
|
||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
return x * cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
@ -117,7 +118,7 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
x = x + cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from importlib.abc import Traversable # pylint: disable=no-name-in-module
|
||||
from importlib.resources import files
|
||||
from pathlib import Path
|
||||
|
||||
@ -10,7 +9,7 @@ KNOWN_CHAT_TEMPLATES = {}
|
||||
|
||||
def _update_known_chat_templates():
|
||||
try:
|
||||
_chat_templates: Traversable = files(__package__) / "chat_templates"
|
||||
_chat_templates = files(__package__) / "chat_templates"
|
||||
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
|
||||
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
|
||||
except ImportError as exc:
|
||||
|
||||
@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, T
|
||||
from transformers.utils import PaddingStrategy
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
from comfy.component_model.tensor_types import RGBImageBatch
|
||||
from ..component_model.tensor_types import RGBImageBatch
|
||||
|
||||
|
||||
class ProcessorResult(TypedDict):
|
||||
|
||||
@ -23,7 +23,7 @@ from einops import rearrange, repeat
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from ..modules.attention import optimized_attention
|
||||
|
||||
|
||||
def get_normalization(name: str, channels: int, weight_args={}, operations=None):
|
||||
|
||||
@ -11,7 +11,7 @@ import math
|
||||
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
|
||||
from torchvision import transforms
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from ..modules.attention import optimized_attention
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
|
||||
@ -1,12 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
from ..flux.layers import DoubleStreamBlock, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
|
||||
|
||||
|
||||
class Hunyuan3Dv2(nn.Module):
|
||||
|
||||
@ -5,10 +5,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
from ..modules.diffusionmodules.model import vae_attention
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
from ...ops import disable_weight_init
|
||||
ops = disable_weight_init
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
@ -31,9 +31,9 @@ import psutil
|
||||
import torch
|
||||
from opentelemetry.trace import get_current_span
|
||||
|
||||
from .cmd.main_pre import tracer
|
||||
from . import interruption
|
||||
from .cli_args import args, PerformanceFeature
|
||||
from .cmd.main_pre import tracer
|
||||
from .component_model.deprecation import _deprecate_method
|
||||
from .model_management_types import ModelManageable
|
||||
|
||||
|
||||
@ -11,12 +11,13 @@ from importlib.metadata import entry_points
|
||||
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
|
||||
from .package_typing import ExportedNodes
|
||||
from ..cmd.main_pre import tracer
|
||||
from .package_typing import ExportedNodes
|
||||
from ..component_model.files import get_package_as_path
|
||||
|
||||
_comfy_nodes: ExportedNodes = ExportedNodes()
|
||||
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
|
||||
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
|
||||
@ -55,7 +56,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
|
||||
except Exception as exc:
|
||||
module_decl = None
|
||||
logging.error(f"{full_name} import failed", exc_info=exc)
|
||||
logger.error(f"{full_name} import failed", exc_info=exc)
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
span.record_exception(exc)
|
||||
exceptions.append(exc)
|
||||
@ -84,7 +85,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
potential_path_error: AttributeError = x
|
||||
if potential_path_error.name == '__path__':
|
||||
continue
|
||||
logging.error(f"{full_name} import failed", exc_info=x)
|
||||
logger.error(f"{full_name} import failed", exc_info=x)
|
||||
success = False
|
||||
exceptions.append(x)
|
||||
span.set_status(Status(StatusCode.ERROR))
|
||||
@ -93,7 +94,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
|
||||
if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
|
||||
for (duration, module_name, success, new_nodes) in sorted(timings):
|
||||
logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
|
||||
logger.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
|
||||
if raise_on_failure and len(exceptions) > 0:
|
||||
try:
|
||||
raise ExceptionGroup("Node import failed", exceptions) # pylint: disable=using-exception-groups-in-unsupported-version
|
||||
@ -105,12 +106,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
|
||||
@tracer.start_as_current_span("Import All Nodes In Workspace")
|
||||
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
|
||||
# now actually import the nodes, to improve control of node loading order
|
||||
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
|
||||
from ..cli_args import args
|
||||
from . import base_nodes
|
||||
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
|
||||
|
||||
# only load these nodes once
|
||||
if len(_comfy_nodes) == 0:
|
||||
if len(_nodes_available_at_startup) == 0:
|
||||
|
||||
# import base_nodes first
|
||||
from . import base_nodes
|
||||
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
|
||||
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
|
||||
|
||||
base_and_extra = reduce(lambda x, y: x.update(y),
|
||||
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
|
||||
# this is the list of default nodes to import
|
||||
@ -121,9 +126,9 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
|
||||
custom_nodes_mappings = ExportedNodes()
|
||||
|
||||
if args.disable_all_custom_nodes:
|
||||
logging.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
|
||||
_comfy_nodes.update(base_and_extra)
|
||||
return _comfy_nodes
|
||||
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
|
||||
_nodes_available_at_startup.update(base_and_extra)
|
||||
return _nodes_available_at_startup
|
||||
|
||||
# load from entrypoints
|
||||
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
|
||||
@ -131,7 +136,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
|
||||
try:
|
||||
module = entry_point.load()
|
||||
except ModuleNotFoundError as module_not_found_error:
|
||||
logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
|
||||
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
|
||||
continue
|
||||
|
||||
# Ensure that what we've loaded is indeed a module
|
||||
@ -146,5 +151,5 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
|
||||
# don't allow custom nodes to overwrite base nodes
|
||||
custom_nodes_mappings -= base_and_extra
|
||||
|
||||
_comfy_nodes.update(base_and_extra + custom_nodes_mappings)
|
||||
return _comfy_nodes
|
||||
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
|
||||
return _nodes_available_at_startup
|
||||
|
||||
@ -7,8 +7,6 @@ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, T
|
||||
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
from comfy.comfy_types import FileLocator
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
|
||||
@ -131,6 +131,7 @@ dev = [
|
||||
"freezegun",
|
||||
"coverage",
|
||||
"pylint",
|
||||
"astroid",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@ -24,19 +24,47 @@ class AbsoluteImportChecker(BaseChecker):
|
||||
super().__init__(linter)
|
||||
|
||||
def visit_importfrom(self, node: nodes.ImportFrom) -> None:
|
||||
current_file = node.root().file
|
||||
if current_file is None:
|
||||
"""
|
||||
Check for absolute imports from the same top-level package.
|
||||
|
||||
This method is called for every `from ... import ...` statement.
|
||||
It checks if a module within 'comfy' or 'comfy_extras' packages
|
||||
is using an absolute import from its own package, which should
|
||||
be a relative import instead.
|
||||
|
||||
For example, inside `comfy/nodes/logic.py`, an import like
|
||||
`from comfy.utils import some_function` will be flagged.
|
||||
The preferred way would be `from ..utils import some_function`.
|
||||
"""
|
||||
# An import is relative if its level is greater than 0.
|
||||
# e.g., from . import foo (level=1), from .. import bar (level=2)
|
||||
# We only want to check absolute imports, so we skip relative ones.
|
||||
if node.level and node.level > 0:
|
||||
return
|
||||
|
||||
package_path = os.path.dirname(current_file)
|
||||
package_name = os.path.basename(package_path)
|
||||
# Get the fully qualified name of the module being linted.
|
||||
# For a file at '.../comfy/nodes/common.py', this will be 'comfy.nodes.common'.
|
||||
module_qname = node.root().qname()
|
||||
|
||||
if node.modname.startswith(package_name) and package_name in ['comfy', 'comfy_extras']:
|
||||
import_parts = node.modname.split('.')
|
||||
# `node.modname` is the module name in the `from` statement.
|
||||
# For `from comfy.utils import x`, `modname` is `comfy.utils`.
|
||||
imported_modname = node.modname
|
||||
if not imported_modname:
|
||||
return
|
||||
|
||||
if import_parts[0] == package_name:
|
||||
self.add_message('absolute-import-used', node=node, args=(node.modname,))
|
||||
# We are only interested in modules within 'comfy' or 'comfy_extras'.
|
||||
# We determine this by looking at the first part of the qualified name.
|
||||
current_top_package = module_qname.split('.')[0]
|
||||
if current_top_package not in ['comfy', 'comfy_extras']:
|
||||
return
|
||||
|
||||
imported_top_package = imported_modname.split('.')[0]
|
||||
|
||||
# If the top-level package of the imported module is the same as the
|
||||
# current module's top-level package, it's an internal absolute import.
|
||||
if imported_top_package == current_top_package:
|
||||
self.add_message('absolute-import-used', node=node, args=(imported_modname,))
|
||||
|
||||
|
||||
def register(linter: "PyLinter") -> None:
|
||||
linter.register_checker(AbsoluteImportChecker(linter))
|
||||
linter.register_checker(AbsoluteImportChecker(linter))
|
||||
@ -1,50 +1,25 @@
|
||||
import sys
|
||||
import time
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import urllib
|
||||
from typing import Tuple, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import socket
|
||||
import subprocess
|
||||
import urllib
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
from typing import Tuple, List
|
||||
|
||||
from comfy.cli_args_types import Configuration
|
||||
|
||||
logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
|
||||
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
|
||||
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
|
||||
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
|
||||
|
||||
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
|
||||
os.environ["TC_HOST"] = "localhost"
|
||||
|
||||
|
||||
def get_lan_ip():
|
||||
"""
|
||||
Finds the host's IP address on the LAN it's connected to.
|
||||
|
||||
Returns:
|
||||
str: The IP address of the host on the LAN.
|
||||
"""
|
||||
# Create a dummy socket
|
||||
s = None
|
||||
try:
|
||||
# Connect to a dummy address (Here, Google's public DNS server)
|
||||
# The actual connection is not made, but this allows finding out the LAN IP
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
s.connect(("8.8.8.8", 80))
|
||||
ip = s.getsockname()[0]
|
||||
finally:
|
||||
if s is not None:
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
|
||||
def run_server(server_arguments: Configuration):
|
||||
from comfy.cmd.main import main
|
||||
from comfy.cli_args import args
|
||||
@ -95,6 +70,11 @@ def has_gpu() -> bool:
|
||||
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
|
||||
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
|
||||
from huggingface_hub import hf_hub_download
|
||||
from testcontainers.rabbitmq import RabbitMqContainer
|
||||
|
||||
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
|
||||
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
|
||||
|
||||
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
|
||||
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
|
||||
|
||||
@ -108,8 +88,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
|
||||
|
||||
frontend_command = [
|
||||
"comfyui",
|
||||
"--listen=0.0.0.0",
|
||||
"--port=9001",
|
||||
"--listen=127.0.0.1",
|
||||
"--port=19001",
|
||||
"--cpu",
|
||||
"--distributed-queue-frontend",
|
||||
f"-w={str(tmp_path)}",
|
||||
@ -122,7 +102,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
|
||||
for i in range(num_workers):
|
||||
backend_command = [
|
||||
"comfyui-worker",
|
||||
f"--port={9002 + i}",
|
||||
f"--port={19002 + i}",
|
||||
f"-w={str(tmp_path)}",
|
||||
f"--distributed-queue-connection-uri={connection_uri}",
|
||||
f"--executor-factory={executor_factory}"
|
||||
@ -130,7 +110,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
|
||||
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
|
||||
|
||||
try:
|
||||
server_address = f"http://{get_lan_ip()}:9001"
|
||||
server_address = f"http://127.0.0.1:19001"
|
||||
start_time = time.time()
|
||||
connected = False
|
||||
while time.time() - start_time < 60:
|
||||
|
||||
@ -15,7 +15,7 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex
|
||||
DependencyCycleError
|
||||
from comfy.distributed.server_stub import ServerStub
|
||||
from comfy.execution_context import context_add_custom_nodes
|
||||
from comfy.graph_utils import GraphBuilder, Node
|
||||
from comfy_execution.graph_utils import GraphBuilder, Node
|
||||
from comfy.nodes.package_typing import ExportedNodes
|
||||
|
||||
current_test_name = ContextVar('current_test_name', default=None)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from comfy.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy.graph import ExecutionBlocker
|
||||
from .tools import VariantSupport
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import time
|
||||
import asyncio
|
||||
from comfy.utils import ProgressBar
|
||||
from .tools import VariantSupport
|
||||
from comfy.graph_utils import GraphBuilder
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
||||
from comfy.comfy_types import IO
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from comfy.graph_utils import GraphBuilder
|
||||
from comfy_execution.graph_utils import GraphBuilder
|
||||
from .tools import VariantSupport
|
||||
|
||||
@VariantSupport()
|
||||
|
||||
129
tests/main_pre_import_checker.py
Normal file
129
tests/main_pre_import_checker.py
Normal file
@ -0,0 +1,129 @@
|
||||
"""Pylint checker for ensuring main_pre is imported first."""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import astroid
|
||||
from astroid import nodes
|
||||
from pylint.checkers import BaseChecker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pylint.lint import PyLinter
|
||||
|
||||
|
||||
class MainPreImportOrderChecker(BaseChecker):
|
||||
"""
|
||||
Ensures that imports from 'comfy.cmd.main_pre' or similar setup modules
|
||||
occur before any other relative imports or imports from the 'comfy' package.
|
||||
|
||||
This is important for code that relies on setup being performed by 'main_pre'
|
||||
before other modules from the same package are imported.
|
||||
"""
|
||||
|
||||
name = 'main-pre-import-order'
|
||||
msgs = {
|
||||
'W0002': (
|
||||
'Setup import %s must be placed before other package imports like %s.',
|
||||
'main-pre-import-not-first',
|
||||
"To ensure necessary setup is performed, 'comfy.cmd.main_pre' or similar "
|
||||
"setup imports must precede other relative imports or imports from the "
|
||||
"'comfy' family of packages."
|
||||
),
|
||||
}
|
||||
|
||||
def _is_main_pre_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
|
||||
"""Checks if an import statement is for 'comfy.cmd.main_pre'."""
|
||||
if isinstance(stmt, nodes.Import):
|
||||
for name, _ in stmt.names:
|
||||
if name == 'comfy.cmd.main_pre' or name.startswith('comfy.cmd.main_pre.'):
|
||||
return True
|
||||
return False
|
||||
|
||||
if isinstance(stmt, nodes.ImportFrom):
|
||||
qname: Optional[str] = None
|
||||
if stmt.level == 0:
|
||||
qname = stmt.modname
|
||||
else:
|
||||
try:
|
||||
# Attempt to resolve the relative import to a fully qualified name
|
||||
imported_module = stmt.do_import_module()
|
||||
qname = imported_module.qname()
|
||||
except astroid.AstroidError:
|
||||
# Fallback for unresolved relative imports, check the literal module name
|
||||
if stmt.modname and stmt.modname.endswith('.main_pre'):
|
||||
return True
|
||||
# Heuristic for `from ..cmd import main_pre` in `comfy/entrypoints/*`
|
||||
if stmt.modname == 'cmd' and stmt.root().qname().startswith('comfy'):
|
||||
for name, _ in stmt.names:
|
||||
if name == 'main_pre':
|
||||
return True
|
||||
|
||||
if not qname:
|
||||
return False
|
||||
|
||||
# from comfy.cmd import main_pre
|
||||
if qname == 'comfy.cmd':
|
||||
for name, _ in stmt.names:
|
||||
if name == 'main_pre':
|
||||
return True
|
||||
|
||||
# from comfy.cmd.main_pre import ... OR from a.b.c.main_pre import ...
|
||||
if qname == 'comfy.cmd.main_pre' or qname.endswith('.main_pre'):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
|
||||
"""
|
||||
Checks if an import is a relative import or an import from
|
||||
the 'comfy' package family, and is not a 'main_pre' import.
|
||||
"""
|
||||
if self._is_main_pre_import(stmt):
|
||||
return False
|
||||
|
||||
if isinstance(stmt, nodes.ImportFrom):
|
||||
if stmt.level and stmt.level > 0: # Any relative import
|
||||
return True
|
||||
if stmt.modname and stmt.modname.startswith('comfy'):
|
||||
return True
|
||||
|
||||
if isinstance(stmt, nodes.Import):
|
||||
for name, _ in stmt.names:
|
||||
if name.startswith('comfy'):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def visit_module(self, node: nodes.Module) -> None:
|
||||
"""Checks the order of imports within a module."""
|
||||
imports = [
|
||||
stmt for stmt in node.body
|
||||
if isinstance(stmt, (nodes.Import, nodes.ImportFrom))
|
||||
]
|
||||
|
||||
main_pre_import_node = None
|
||||
for stmt in imports:
|
||||
if self._is_main_pre_import(stmt):
|
||||
main_pre_import_node = stmt
|
||||
break
|
||||
|
||||
# If there's no main_pre import, there's nothing to check.
|
||||
if not main_pre_import_node:
|
||||
return
|
||||
|
||||
for stmt in imports:
|
||||
# We only care about imports that appear before the main_pre import
|
||||
if stmt.fromlineno >= main_pre_import_node.fromlineno:
|
||||
break
|
||||
|
||||
if self._is_other_relevant_import(stmt):
|
||||
self.add_message(
|
||||
'main-pre-import-not-first',
|
||||
node=main_pre_import_node,
|
||||
args=(main_pre_import_node.as_string(), stmt.as_string())
|
||||
)
|
||||
return # Report once per file and exit to avoid spam
|
||||
|
||||
|
||||
def register(linter: "PyLinter") -> None:
|
||||
"""This function is required for a Pylint plugin."""
|
||||
linter.register_checker(MainPreImportOrderChecker(linter))
|
||||
292
tests/unit/prompt_server_test/user_manager_test.py
Normal file
292
tests/unit/prompt_server_test/user_manager_test.py
Normal file
@ -0,0 +1,292 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
from comfy.app.user_manager import UserManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_manager(tmp_path):
|
||||
um = UserManager()
|
||||
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
|
||||
tmp_path, file
|
||||
) if file else tmp_path
|
||||
return um
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(user_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
user_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
|
||||
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir")
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir")
|
||||
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == ["file1.txt"]
|
||||
|
||||
|
||||
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||
assert resp.status == 200
|
||||
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
|
||||
|
||||
|
||||
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir")
|
||||
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&full_info=true")
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert len(result) == 1
|
||||
assert result[0]["path"] == "file1.txt"
|
||||
assert "size" in result[0]
|
||||
assert "modified" in result[0]
|
||||
|
||||
|
||||
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
|
||||
|
||||
|
||||
async def test_listuserdata_invalid_directory(aiohttp_client, app):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=")
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
|
||||
os_sep = "\\"
|
||||
with patch("os.sep", os_sep):
|
||||
with patch("os.path.sep", os_sep):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/userdata?dir=test_dir&recurse=true")
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert len(result) == 1
|
||||
assert "/" in result[0] # Ensure forward slash is used
|
||||
assert "\\" not in result[0] # Ensure backslash is not present
|
||||
assert result[0] == "subdir/file1.txt"
|
||||
|
||||
# Test with full_info
|
||||
resp = await client.get(
|
||||
"/userdata?dir=test_dir&recurse=true&full_info=true"
|
||||
)
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert len(result) == 1
|
||||
assert "/" in result[0]["path"] # Ensure forward slash is used
|
||||
assert "\\" not in result[0]["path"] # Ensure backslash is not present
|
||||
assert result[0]["path"] == "subdir/file1.txt"
|
||||
|
||||
|
||||
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
content = b"test content"
|
||||
resp = await client.post("/userdata/test.txt", data=content)
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"test.txt"'
|
||||
|
||||
# Verify file was created with correct content
|
||||
with open(tmp_path / "test.txt", "rb") as f:
|
||||
assert f.read() == content
|
||||
|
||||
|
||||
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "test.txt", "w") as f:
|
||||
f.write("initial content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
new_content = b"updated content"
|
||||
resp = await client.post("/userdata/test.txt", data=new_content)
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"test.txt"'
|
||||
|
||||
# Verify file was overwritten
|
||||
with open(tmp_path / "test.txt", "rb") as f:
|
||||
assert f.read() == new_content
|
||||
|
||||
|
||||
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "test.txt", "w") as f:
|
||||
f.write("initial content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
|
||||
|
||||
assert resp.status == 409
|
||||
|
||||
# Verify original content unchanged
|
||||
with open(tmp_path / "test.txt", "r") as f:
|
||||
assert f.read() == "initial content"
|
||||
|
||||
|
||||
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
|
||||
client = await aiohttp_client(app)
|
||||
content = b"test content"
|
||||
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
|
||||
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert result["path"] == "test.txt"
|
||||
assert result["size"] == len(content)
|
||||
assert "modified" in result
|
||||
|
||||
|
||||
async def test_move_userdata(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt")
|
||||
|
||||
assert resp.status == 200
|
||||
assert await resp.text() == '"dest.txt"'
|
||||
|
||||
# Verify file was moved
|
||||
assert not os.path.exists(tmp_path / "source.txt")
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "test content"
|
||||
|
||||
|
||||
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
|
||||
# Create source and destination files
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("source content")
|
||||
with open(tmp_path / "dest.txt", "w") as f:
|
||||
f.write("destination content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
|
||||
|
||||
assert resp.status == 409
|
||||
|
||||
# Verify files remain unchanged
|
||||
with open(tmp_path / "source.txt", "r") as f:
|
||||
assert f.read() == "source content"
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "destination content"
|
||||
|
||||
|
||||
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
|
||||
# Create initial file
|
||||
with open(tmp_path / "source.txt", "w") as f:
|
||||
f.write("test content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
|
||||
|
||||
assert resp.status == 200
|
||||
result = await resp.json()
|
||||
assert result["path"] == "dest.txt"
|
||||
assert result["size"] == len("test content")
|
||||
assert "modified" in result
|
||||
|
||||
# Verify file was moved
|
||||
assert not os.path.exists(tmp_path / "source.txt")
|
||||
with open(tmp_path / "dest.txt", "r") as f:
|
||||
assert f.read() == "test content"
|
||||
|
||||
|
||||
async def test_listuserdata_v2_empty_root(aiohttp_client, app):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/v2/userdata")
|
||||
assert resp.status == 200
|
||||
assert await resp.json() == []
|
||||
|
||||
|
||||
async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app):
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/v2/userdata?path=does_not_exist")
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path):
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
(tmp_path / "test_dir" / "file1.txt").write_text("content")
|
||||
(tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/v2/userdata?path=test_dir")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
file_paths = {item["path"] for item in data if item["type"] == "file"}
|
||||
assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"}
|
||||
|
||||
|
||||
async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch):
|
||||
# Force backslash as os separator
|
||||
monkeypatch.setattr(os, 'sep', '\\')
|
||||
monkeypatch.setattr(os.path, 'sep', '\\')
|
||||
os.makedirs(tmp_path / "test_dir" / "subdir")
|
||||
(tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/v2/userdata?path=test_dir")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
for item in data:
|
||||
assert "/" in item["path"]
|
||||
assert "\\" not in item["path"]
|
||||
|
||||
|
||||
async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path):
|
||||
# Create a directory with a space in its name and a file inside
|
||||
os.makedirs(tmp_path / "my dir")
|
||||
(tmp_path / "my dir" / "file.txt").write_text("content")
|
||||
|
||||
client = await aiohttp_client(app)
|
||||
# Use URL-encoded space in path parameter
|
||||
resp = await client.get("/v2/userdata?path=my%20dir&recurse=false")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert len(data) == 1
|
||||
entry = data[0]
|
||||
assert entry["name"] == "file.txt"
|
||||
# Ensure the path is correctly decoded and uses forward slash
|
||||
assert entry["path"] == "my dir/file.txt"
|
||||
@ -122,7 +122,7 @@ def test_bool_request_parameter():
|
||||
assert v == True
|
||||
|
||||
|
||||
def test_string_enum_request_parameter():
|
||||
async def test_string_enum_request_parameter():
|
||||
nt = StringEnumRequestParameter.INPUT_TYPES()
|
||||
assert nt is not None
|
||||
n = StringEnumRequestParameter()
|
||||
@ -155,8 +155,9 @@ def test_string_enum_request_parameter():
|
||||
}
|
||||
from comfy.cmd.execution import validate_inputs
|
||||
validated: dict[str, ValidateInputsTuple] = {}
|
||||
validated["1"] = validate_inputs(prompt, "1", validated)
|
||||
validated["2"] = validate_inputs(prompt, "2", validated)
|
||||
prompt_id = str(uuid.uuid4())
|
||||
validated["1"] = await validate_inputs(prompt_id, prompt, "1", validated)
|
||||
validated["2"] = await validate_inputs(prompt_id, prompt, "2", validated)
|
||||
assert validated["2"].valid
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user