mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 14:20:49 +08:00)
packaging fixes
- enable user db
- fix main_pre order everywhere
- fix absolute to relative imports everywhere
- async better supported
parent c086c5e005
commit 96b4e04315
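The first item, "enable user db", replaces the in-memory SQLite default with a database file in the user's data directory. A minimal sketch of the resolution logic, mirroring the db_config() helper added in the diff below (the standalone appdirs import here is illustrative; the repository vendors its own copy of appdirs):

import os
from appdirs import user_data_dir

def default_database_url() -> str:
    # Resolve a per-user data directory; fall back to an in-memory DB if it cannot be created.
    try:
        data_dir = user_data_dir(appname="comfyui")
        os.makedirs(data_dir, exist_ok=True)
        return f"sqlite:///{os.path.join(data_dir, 'comfy.db')}"
    except Exception:
        return "sqlite:///:memory:"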
@@ -82,7 +82,7 @@ limit-inference-results=100
 
 # List of plugins (as comma separated values of python module names) to load,
 # usually to register additional checkers.
-load-plugins=tests.absolute_import_checker
+load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker
 
 # Pickle collected data for later comparisons.
 persistent=yes
@@ -678,7 +678,7 @@ disable=raw-checker-failed,
 # either give multiple identifier separated by comma (,) or put this option
 # multiple time (only on the command line, not in the configuration file where
 # it should appear only once). See also the "--disable" option for examples.
-enable=absolute-import-used
+enable=
 
 [METHOD_ARGS]
 
@@ -13,7 +13,7 @@ script_location = alembic_db
 
 # sys.path path, will be prepended to sys.path if present.
 # defaults to the current working directory.
-prepend_sys_path = .
+# prepend_sys_path = .
 
 # timezone to use when rendering the date within the migration file
 # as well as the filename.
@@ -63,7 +63,7 @@ version_path_separator = os
 # are written from script.py.mako
 # output_encoding = utf-8
 
-sqlalchemy.url = sqlite:///user/comfyui.db
+sqlalchemy.url = sqlite:///./comfyui.db
 
 
 [post_write_hooks]
comfy/alembic_db/__init__.py (new file, 0 lines)
@@ -1,3 +1,4 @@
+# pylint: disable=no-member
 from sqlalchemy import engine_from_config
 from sqlalchemy import pool
 
@@ -7,8 +8,7 @@ from alembic import context
 # access to the values within the .ini file in use.
 config = context.config
 
-from comfy.app.database.models import Base
+from ..app.database.models import Base
 target_metadata = Base.metadata
 
 # other values from the config, defined by the needs of env.py,
@@ -1,8 +1,10 @@
 import logging
 import os
 import shutil
+from importlib.resources import files
 
 from ...cli_args import args
+from ...component_model.files import get_package_as_path
 
 Session = None
 
@@ -15,6 +17,7 @@ from sqlalchemy.orm import sessionmaker
 
 _DB_AVAILABLE = True
 
+logger = logging.getLogger(__name__)
 
 def dependencies_available():
 """
@@ -32,9 +35,8 @@ def can_create_session():
 
 
 def get_alembic_config():
-root_path = os.path.join(os.path.dirname(__file__), "../..")
-config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
-scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
+config_path = str(files("comfy") / "alembic.ini")
+scripts_path = get_package_as_path("comfy.alembic_db")
 
 config = Config(config_path)
 config.set_main_option("script_location", scripts_path)
@@ -53,7 +55,7 @@ def get_db_path():
 
 def init_db():
 db_url = args.database_url
-logging.debug(f"Database URL: {db_url}")
+logger.debug(f"Database URL: {db_url}")
 db_path = get_db_path()
 db_exists = os.path.exists(db_path)
 
@@ -70,7 +72,7 @@ def init_db():
 target_rev = script.get_current_head()
 
 if target_rev is None:
-logging.warning("No target revision found.")
+logger.debug("No target revision found.")
 elif current_rev != target_rev:
 # Backup the database pre upgrade
 backup_path = db_path + ".bkp"
@@ -81,13 +83,13 @@ def init_db():
 
 try:
 command.upgrade(config, target_rev)
-logging.info(f"Database upgraded from {current_rev} to {target_rev}")
+logger.info(f"Database upgraded from {current_rev} to {target_rev}")
 except Exception as e:
 if backup_path:
 # Restore the database from backup if upgrade fails
 shutil.copy(backup_path, db_path)
 os.remove(backup_path)
-logging.exception("Error upgrading database: ")
+logger.exception("Error upgrading database: ")
 raise e
 
 global Session
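Packaging note: get_alembic_config() now resolves alembic.ini and the migration scripts from the installed package instead of walking up from __file__, so the database can be initialized from a wheel install. A rough sketch of the lookup, assuming a packaged alembic.ini (get_package_as_path is the repository's own helper for materializing a package directory; its exact behavior is assumed here, so the scripts path is taken as a parameter):

from importlib.resources import files
from alembic.config import Config

def load_packaged_alembic_config(scripts_path: str) -> Config:
    # alembic.ini ships inside the installed "comfy" package, so resolve it as a package resource
    config_path = str(files("comfy") / "alembic.ini")
    config = Config(config_path)
    # point Alembic at the migration scripts directory that ships with the package
    config.set_main_option("script_location", scripts_path)
    return config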
@@ -5,7 +5,8 @@ import torch
 import torch.nn as nn
 from torch import Tensor
 
-from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
+from ..ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, \
+get_2d_sincos_pos_embed_torch
 
 
 class ControlNetEmbedder(nn.Module):
@@ -11,7 +11,7 @@ import configargparse as argparse
 from . import __version__
 from . import options
 from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
-EnhancedConfigArgParser, PerformanceFeature, is_valid_directory
+EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config
 
 # todo: move this
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@@ -261,8 +261,8 @@ def _create_parser() -> EnhancedConfigArgParser:
 help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
 )
 
-parser.add_argument("--database-url", type=str, default=f"sqlite:///:memory:", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+default_db_url = db_config()
+parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
 parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
 
 # now give plugins a chance to add configuration
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import enum
+import logging
 import os
 from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
 
@@ -20,6 +21,22 @@ class LatentPreviewMethod(enum.Enum):
 ConfigObserver = Callable[[str, Any], None]
 
 
+def db_config() -> str:
+from .vendor.appdirs import user_data_dir
+
+logger = logging.getLogger(__name__)
+try:
+data_dir = user_data_dir(appname="comfyui")
+os.makedirs(data_dir, exist_ok=True)
+db_path = os.path.join(data_dir, "comfy.db")
+default_db_url = f"sqlite:///{db_path}"
+except Exception as e:
+# Fallback to an in-memory database if the user directory can't be accessed
+logger.warning(f"Could not determine user data directory for database, falling back to in-memory: {e}")
+default_db_url = "sqlite:///:memory:"
+return default_db_url
+
+
 def is_valid_directory(path: str) -> str:
 """Validate if the given path is a directory, and check permissions."""
 if not os.path.exists(path):
@@ -261,7 +278,7 @@ class Configuration(dict):
 self.front_end_version: str = "comfyanonymous/ComfyUI@latest"
 self.front_end_root: Optional[str] = None
 self.comfy_api_base: str = "https://api.comfy.org"
-self.database_url: str = "sqlite:///:memory:"
+self.database_url: str = db_config()
 
 for key, value in kwargs.items():
 self[key] = value
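With this default in place, Configuration.database_url and the --database-url flag both follow the per-user path, and the old behavior remains one flag away. Illustrative usage; the exact path depends on the platform's user data directory:

from comfy.cli_args_types import db_config

default_url = db_config()
# typically something like sqlite:////home/<user>/.local/share/comfyui/comfy.db on Linux,
# falling back to sqlite:///:memory: if the directory cannot be created

To restore the previous in-memory behavior explicitly:

comfyui --database-url sqlite:///:memory: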
@@ -14,11 +14,11 @@ from opentelemetry import context, propagate
 from opentelemetry.context import Context, attach, detach
 from opentelemetry.trace import Status, StatusCode
 
+from ..cmd.main_pre import tracer
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
 from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
-from ..cmd.main_pre import tracer
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.make_mutable import make_mutable
 from ..distributed.executors import ContextVarExecutor
@@ -97,7 +97,7 @@ async def __execute_prompt(
 else:
 prompt_executor.server = progress_handler
 
-prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
+await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
 execute_outputs=validation_tuple.good_output_node_ids)
 return prompt_executor.outputs_ui
 except Exception as exc_info:
@@ -195,7 +195,7 @@ class Comfy:
 if isinstance(prompt, str):
 prompt = json.loads(prompt)
 if isinstance(prompt, dict):
-from comfy.api.components.schema.prompt import Prompt
+from ..api.components.schema.prompt import Prompt
 prompt = Prompt.validate(prompt)
 outputs = await self.queue_prompt(prompt)
 return V1QueuePromptResponse(urls=[], outputs=outputs)
@@ -19,14 +19,15 @@ from typing import List, Optional, Tuple, Literal
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status
 
+# order matters
+from .main_pre import tracer
+
 from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
 DependencyAwareCache, \
 BasicCache
-# order matters
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
 from comfy_execution.utils import CurrentNodeContext
-from .main_pre import tracer
 from .. import interruption
 from .. import model_management
 from ..cli_args import args
@@ -37,7 +38,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
 HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.module_property import create_module_properties
-from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
+from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
 from ..execution_context import context_execute_node, context_execute_prompt
 from ..execution_ext import should_panic_on_exception
 from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@@ -388,7 +389,7 @@ def format_value(x) -> FormattedValue:
 return str(x.__class__)
 
 
-async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
+async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
 """
 
 :param server:
@@ -402,8 +403,8 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
 :param pending_subgraph_results:
 :return:
 """
-with context_execute_node(_node_id):
-return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+with context_execute_node(node_id):
+return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
 
 
 async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
@@ -516,7 +517,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
 unblock()
 
 asyncio.create_task(await_completion())
-return (ExecutionResult.PENDING, None, None)
+return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
 if len(output_ui) > 0:
 caches.ui.set(unique_id, {
 "meta": {
@@ -685,11 +686,6 @@ class PromptExecutor:
 if ex is not None and self.raise_exceptions:
 raise ex
 
-def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
-asyncio_loop = asyncio.new_event_loop()
-asyncio.set_event_loop(asyncio_loop)
-asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
-
 async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
 # torchao and potentially other optimization approaches break when the models are created in inference mode
 # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
@@ -1206,7 +1202,7 @@ class PromptQueue(AbstractPromptQueue):
 self.server.queue_updated()
 return copy.deepcopy(item_with_future.queue_tuple), task_id
 
-def task_done(self, item_id: str, outputs: dict,
+def task_done(self, item_id: str, outputs: HistoryResultDict,
 status: Optional[ExecutionStatus]):
 history_result = outputs
 with self.mutex:
@@ -1215,9 +1211,9 @@ class PromptQueue(AbstractPromptQueue):
 if len(self.history) > MAXIMUM_HISTORY_SIZE:
 self.history.pop(next(iter(self.history)))
 
-status_dict: Optional[dict] = None
+status_dict = None
 if status is not None:
-status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
+status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
 
 outputs_ = history_result["outputs"]
 # Remove sensitive data from extra_data before storing in history
@@ -1225,11 +1221,13 @@ class PromptQueue(AbstractPromptQueue):
 if sensitive_val in prompt[3]:
 prompt[3].pop(sensitive_val)
 
-self.history[prompt[1]] = {
+history_entry: HistoryEntry = {
 "prompt": prompt,
 "outputs": copy.deepcopy(outputs_),
-'status': status_dict,
 }
+if status_dict is not None:
+history_entry["status"] = status_dict
+self.history[prompt[1]] = history_entry
 self.history[prompt[1]].update(history_result)
 self.server.queue_updated()
 if queue_item.completed:
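Because the synchronous PromptExecutor.execute wrapper is removed, callers that are not already inside an event loop now drive execute_async themselves, mirroring what embedded_comfy_client and the prompt worker do in this commit. A hedged sketch; executor, prompt, prompt_id, client_id and output_node_ids stand in for values the caller already has:

import asyncio

# before: executor.execute(prompt, prompt_id, {"client_id": client_id}, execute_outputs=output_node_ids)
asyncio.run(executor.execute_async(prompt, prompt_id, {"client_id": client_id},
                                   execute_outputs=output_node_ids))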
@@ -10,20 +10,22 @@ import time
 from pathlib import Path
 from typing import Optional
 
-from comfy.component_model.entrypoints_common import configure_application_paths, executor_from_args
+# main_pre must be the earliest import
+from .main_pre import args
+
 from . import hook_breaker_ac10a0
 from .extra_model_paths import load_extra_path_config
-# main_pre must be the earliest import since it suppresses some spurious warnings
-from .main_pre import args
 from .. import model_management
 from ..analytics.analytics import initialize_event_tracking
 from ..cmd import cuda_malloc
 from ..cmd import folder_paths
 from ..cmd import server as server_module
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
+from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
 from ..distributed.distributed_prompt_queue import DistributedPromptQueue
 from ..distributed.server_stub import ServerStub
 from ..nodes.package import import_all_nodes_in_workspace
+from ..nodes_context import get_nodes
 
 logger = logging.getLogger(__name__)
 
@@ -42,6 +44,10 @@ def cuda_malloc_warning():
 
 
 def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
+asyncio.run(_prompt_worker(q, server_instance))
+
+
+async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
 from ..cmd import execution
 from ..component_model import queue_types
 from .. import model_management
@@ -68,7 +74,7 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS
 prompt_id = item[1]
 server_instance.last_prompt_id = prompt_id
 
-e.execute(item[2], prompt_id, item[3], item[4])
+await e.execute_async(item[2], prompt_id, item[3], item[4])
 need_gc = True
 q.task_done(item_id,
 e.history_result,
@@ -174,17 +180,16 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
 for config_path in itertools.chain(*args.extra_model_paths_config):
 load_extra_path_config(config_path)
 
-# always create directories when started interactively
-folder_paths.create_directories()
 if args.create_directories:
+# then, import and exit
 import_all_nodes_in_workspace(raise_on_failure=False)
 folder_paths.create_directories()
 exit(0)
-setup_database()
+elif args.quick_test_for_ci:
+import_all_nodes_in_workspace(raise_on_failure=True)
+exit(0)
 
 if args.windows_standalone_build:
-folder_paths.create_directories()
 try:
 from . import new_updater
 new_updater.update_windows_updater()
@@ -198,7 +203,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
 
 # at this stage, it's safe to import nodes
 hook_breaker_ac10a0.save_functions()
-server.nodes = import_all_nodes_in_workspace()
+server.nodes = get_nodes()
 hook_breaker_ac10a0.restore_functions()
 # as a side effect, this also populates the nodes for execution
 
@@ -221,6 +226,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
 
 server.add_routes()
 cuda_malloc_warning()
+setup_database()
 
 # in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
 worker_thread_server = server if not distributed else ServerStub()
@@ -254,14 +260,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
 logger.debug(f"Setting input directory to: {input_dir}")
 folder_paths.set_input_directory(input_dir)
 
-if args.quick_test_for_ci:
-# for CI purposes, try importing all the nodes
-import_all_nodes_in_workspace(raise_on_failure=True)
-return
-else:
-# we no longer lazily load nodes. we'll do it now for the sake of creating directories
-import_all_nodes_in_workspace(raise_on_failure=False)
-# now that nodes are loaded, create more directories if appropriate
+# now that nodes are loaded, create directories
 folder_paths.create_directories()
 
 if len(args.workflows) > 0:
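The "fix main_pre order everywhere" item enforces the rule stated in the comment above: main_pre has to be imported before anything else, since it applies side effects (the old comment mentions suppressing some spurious warnings, and it exposes the tracer used throughout) that must be in place before heavier imports run. A sketch of the convention as seen from outside the package; module paths are those used in this repository:

# main_pre must be the earliest import
from comfy.cmd.main_pre import args, tracer

# heavier imports only after main_pre has run
from comfy.cmd import folder_paths
from comfy.cmd import server as server_module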
@@ -1169,7 +1169,7 @@ class PromptServer(ExecutorToClientProgress):
 await runner.setup()
 
 if 'tls_keyfile' in args or 'tls_certfile' in args:
-raise ValueError("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
+logger.warning("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
 
 def is_ipv4(address: str, *args):
 try:
@@ -3,6 +3,7 @@ from __future__ import annotations
 import typing
 from abc import ABCMeta, abstractmethod
 
+from .executor_types import HistoryResultDict
 from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems
 
 
@@ -42,7 +43,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
 pass
 
 @abstractmethod
-def task_done(self, item_id: str, outputs: dict,
+def task_done(self, item_id: str, outputs: HistoryResultDict,
 status: typing.Optional[ExecutionStatus]):
 """
 Signals to the user interface that the task with the specified id is completed
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import copy
 from enum import Enum
 from typing import NamedTuple, Optional, List, Literal, Sequence
 from typing import Tuple
@@ -24,6 +25,13 @@ class ExecutionStatus(NamedTuple):
 completed: bool
 messages: List[str]
 
+def as_dict(self) -> ExecutionStatusAsDict:
+return {
+"status_str": self.status_str,
+"completed": self.completed,
+"messages": copy.copy(self.messages),
+}
+
 
 class ExecutionError(RuntimeError):
 def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args):
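as_dict() gives the queue a typed replacement for the old copy.deepcopy(ExecutionStatus(*status)._asdict()) pattern when building history entries. A small usage sketch; the field values shown are illustrative:

status = ExecutionStatus(status_str="success", completed=True, messages=[])
status.as_dict()
# -> {"status_str": "success", "completed": True, "messages": []}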
@@ -13,12 +13,12 @@ from aio_pika import connect_robust
 from aio_pika.abc import AbstractConnection, AbstractChannel
 from aio_pika.patterns import JsonRPC
 
+from ..cmd.main_pre import tracer
 from .distributed_progress import ProgressHandlers
 from .distributed_types import RpcRequest, RpcReply
 from .history import History
 from .server_stub import ServerStub
 from ..auth.permissions import jwt_decode
-from ..cmd.main_pre import tracer
 from ..cmd.server import PromptServer
 from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict
@@ -10,12 +10,12 @@ from aio_pika.patterns import JsonRPC
 from aiohttp import web
 from aiormq import AMQPConnectionError
 
+from ..cmd.main_pre import tracer
 from .executors import ContextVarExecutor
 from .distributed_progress import DistributedExecutorToClientProgress
 from .distributed_types import RpcRequest, RpcReply
 from .process_pool_executor import ProcessPoolExecutor
 from ..client.embedded_comfy_client import Comfy
-from ..cmd.main_pre import tracer
 from ..component_model.queue_types import ExecutionStatus
 
 logger = logging.getLogger(__name__)
@@ -2,7 +2,6 @@ import asyncio
 
 from ..cmd.main_pre import args
 from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
-from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor
 
 
 async def main():
@@ -1,7 +1,8 @@
 import torch
-from comfy.text_encoders.bert import BertAttention
-import comfy.model_management
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from ..ldm.modules.attention import optimized_attention_for_device
+from ..model_management import cast_to_device
+from ..text_encoders.bert import BertAttention
 
 
 class Dino2AttentionOutput(torch.nn.Module):
@@ -29,7 +30,7 @@ class LayerScale(torch.nn.Module):
 self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
 
 def forward(self, x):
-return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+return x * cast_to_device(self.lambda1, x.device, x.dtype)
 
 
 class SwiGLUFFN(torch.nn.Module):
@@ -117,7 +118,7 @@ class Dino2Embeddings(torch.nn.Module):
 x = self.patch_embeddings(pixel_values)
 # TODO: mask_token?
 x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
-x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
+x = x + cast_to_device(self.position_embeddings, x.device, x.dtype)
 return x
 
 
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-from importlib.abc import Traversable # pylint: disable=no-name-in-module
 from importlib.resources import files
 from pathlib import Path
 
@@ -10,7 +9,7 @@ KNOWN_CHAT_TEMPLATES = {}
 
 def _update_known_chat_templates():
 try:
-_chat_templates: Traversable = files(__package__) / "chat_templates"
+_chat_templates = files(__package__) / "chat_templates"
 _extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
 KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
 except ImportError as exc:
@@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, T
 from transformers.utils import PaddingStrategy
 from typing_extensions import TypedDict, NotRequired
 
-from comfy.component_model.tensor_types import RGBImageBatch
+from ..component_model.tensor_types import RGBImageBatch
 
 
 class ProcessorResult(TypedDict):
@@ -23,7 +23,7 @@ from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from comfy.ldm.modules.attention import optimized_attention
+from ..modules.attention import optimized_attention
 
 
 def get_normalization(name: str, channels: int, weight_args={}, operations=None):
@@ -11,7 +11,7 @@ import math
 from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
 from torchvision import transforms
 
-from comfy.ldm.modules.attention import optimized_attention
+from ..modules.attention import optimized_attention
 
 def apply_rotary_pos_emb(
 t: torch.Tensor,
@@ -1,12 +1,6 @@
 import torch
 from torch import nn
-from comfy.ldm.flux.layers import (
-DoubleStreamBlock,
-LastLayer,
-MLPEmbedder,
-SingleStreamBlock,
-timestep_embedding,
-)
+from ..flux.layers import DoubleStreamBlock, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
 
 
 class Hunyuan3Dv2(nn.Module):
@@ -5,10 +5,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from comfy.ldm.modules.diffusionmodules.model import vae_attention
+from ..modules.diffusionmodules.model import vae_attention
 
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ...ops import disable_weight_init
+ops = disable_weight_init
 
 CACHE_T = 2
 
@@ -31,9 +31,9 @@ import psutil
 import torch
 from opentelemetry.trace import get_current_span
 
+from .cmd.main_pre import tracer
 from . import interruption
 from .cli_args import args, PerformanceFeature
-from .cmd.main_pre import tracer
 from .component_model.deprecation import _deprecate_method
 from .model_management_types import ModelManageable
 
@@ -11,12 +11,13 @@ from importlib.metadata import entry_points
 
 from opentelemetry.trace import Span, Status, StatusCode
 
-from .package_typing import ExportedNodes
 from ..cmd.main_pre import tracer
+from .package_typing import ExportedNodes
 from ..component_model.files import get_package_as_path
 
-_comfy_nodes: ExportedNodes = ExportedNodes()
+_nodes_available_at_startup: ExportedNodes = ExportedNodes()
 
+logger = logging.getLogger(__name__)
+
 def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
 node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
@@ -55,7 +56,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
 timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
 except Exception as exc:
 module_decl = None
-logging.error(f"{full_name} import failed", exc_info=exc)
+logger.error(f"{full_name} import failed", exc_info=exc)
 span.set_status(Status(StatusCode.ERROR))
 span.record_exception(exc)
 exceptions.append(exc)
@@ -84,7 +85,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
 potential_path_error: AttributeError = x
 if potential_path_error.name == '__path__':
 continue
-logging.error(f"{full_name} import failed", exc_info=x)
+logger.error(f"{full_name} import failed", exc_info=x)
 success = False
 exceptions.append(x)
 span.set_status(Status(StatusCode.ERROR))
@@ -93,7 +94,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
 
 if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
 for (duration, module_name, success, new_nodes) in sorted(timings):
-logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
+logger.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
 if raise_on_failure and len(exceptions) > 0:
 try:
 raise ExceptionGroup("Node import failed", exceptions) # pylint: disable=using-exception-groups-in-unsupported-version
@@ -105,12 +106,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
 @tracer.start_as_current_span("Import All Nodes In Workspace")
 def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
 # now actually import the nodes, to improve control of node loading order
-from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
 from ..cli_args import args
-from . import base_nodes
-from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
 # only load these nodes once
-if len(_comfy_nodes) == 0:
+if len(_nodes_available_at_startup) == 0:
+
+# import base_nodes first
+from . import base_nodes
+from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
+from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
+
 base_and_extra = reduce(lambda x, y: x.update(y),
 map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
 # this is the list of default nodes to import
@@ -121,9 +126,9 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
 custom_nodes_mappings = ExportedNodes()
 
 if args.disable_all_custom_nodes:
-logging.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
-_comfy_nodes.update(base_and_extra)
-return _comfy_nodes
+logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
+_nodes_available_at_startup.update(base_and_extra)
+return _nodes_available_at_startup
 
 # load from entrypoints
 for entry_point in entry_points().select(group='comfyui.custom_nodes'):
@@ -131,7 +136,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
 try:
 module = entry_point.load()
 except ModuleNotFoundError as module_not_found_error:
-logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
+logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
 continue
 
 # Ensure that what we've loaded is indeed a module
@@ -146,5 +151,5 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
 # don't allow custom nodes to overwrite base nodes
 custom_nodes_mappings -= base_and_extra
 
-_comfy_nodes.update(base_and_extra + custom_nodes_mappings)
-return _comfy_nodes
+_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
+return _nodes_available_at_startup
@@ -7,8 +7,6 @@ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, T
 
 from typing_extensions import TypedDict, NotRequired
 
-from comfy.comfy_types import FileLocator
-
 T = TypeVar('T')
 
 
@@ -131,6 +131,7 @@ dev = [
 "freezegun",
 "coverage",
 "pylint",
+"astroid",
 ]
 
 [project.optional-dependencies]
@@ -24,18 +24,46 @@ class AbsoluteImportChecker(BaseChecker):
 super().__init__(linter)
 
 def visit_importfrom(self, node: nodes.ImportFrom) -> None:
-current_file = node.root().file
-if current_file is None:
+"""
+Check for absolute imports from the same top-level package.
+
+This method is called for every `from ... import ...` statement.
+It checks if a module within 'comfy' or 'comfy_extras' packages
+is using an absolute import from its own package, which should
+be a relative import instead.
+
+For example, inside `comfy/nodes/logic.py`, an import like
+`from comfy.utils import some_function` will be flagged.
+The preferred way would be `from ..utils import some_function`.
+"""
+# An import is relative if its level is greater than 0.
+# e.g., from . import foo (level=1), from .. import bar (level=2)
+# We only want to check absolute imports, so we skip relative ones.
+if node.level and node.level > 0:
 return
-package_path = os.path.dirname(current_file)
-package_name = os.path.basename(package_path)
-if node.modname.startswith(package_name) and package_name in ['comfy', 'comfy_extras']:
-import_parts = node.modname.split('.')
-if import_parts[0] == package_name:
-self.add_message('absolute-import-used', node=node, args=(node.modname,))
+
+# Get the fully qualified name of the module being linted.
+# For a file at '.../comfy/nodes/common.py', this will be 'comfy.nodes.common'.
+module_qname = node.root().qname()
+
+# `node.modname` is the module name in the `from` statement.
+# For `from comfy.utils import x`, `modname` is `comfy.utils`.
+imported_modname = node.modname
+if not imported_modname:
+return
+
+# We are only interested in modules within 'comfy' or 'comfy_extras'.
+# We determine this by looking at the first part of the qualified name.
+current_top_package = module_qname.split('.')[0]
+if current_top_package not in ['comfy', 'comfy_extras']:
+return
+
+imported_top_package = imported_modname.split('.')[0]
+
+# If the top-level package of the imported module is the same as the
+# current module's top-level package, it's an internal absolute import.
+if imported_top_package == current_top_package:
+self.add_message('absolute-import-used', node=node, args=(imported_modname,))
 
 
 def register(linter: "PyLinter") -> None:
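In short, the reworked checker keys off the module's qualified name instead of its file path, so it flags any module under comfy or comfy_extras that imports its own top-level package absolutely. The docstring's own example illustrates what trips it and what passes:

# inside comfy/nodes/logic.py
from comfy.utils import some_function   # flagged: absolute-import-used
from ..utils import some_function       # preferred relative form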
@@ -1,50 +1,25 @@
-import sys
-import time
-
 import logging
 import multiprocessing
 import os
 import pathlib
+import subprocess
+import sys
+import time
+import urllib
+from typing import Tuple, List
 
 import pytest
 import requests
-import socket
-import subprocess
-import urllib
-from testcontainers.rabbitmq import RabbitMqContainer
-from typing import Tuple, List
 
 from comfy.cli_args_types import Configuration
 
 logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
 logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
-logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
-logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
 
 # fixes issues with running the testcontainers rabbitmqcontainer on Windows
 os.environ["TC_HOST"] = "localhost"
 
 
-def get_lan_ip():
-"""
-Finds the host's IP address on the LAN it's connected to.
-
-Returns:
-str: The IP address of the host on the LAN.
-"""
-# Create a dummy socket
-s = None
-try:
-# Connect to a dummy address (Here, Google's public DNS server)
-# The actual connection is not made, but this allows finding out the LAN IP
-s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-s.connect(("8.8.8.8", 80))
-ip = s.getsockname()[0]
-finally:
-if s is not None:
-s.close()
-return ip
-
-
 def run_server(server_arguments: Configuration):
 from comfy.cmd.main import main
 from comfy.cli_args import args
@ -95,6 +70,11 @@ def has_gpu() -> bool:
|
|||||||
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
|
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
|
||||||
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
|
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
|
from testcontainers.rabbitmq import RabbitMqContainer
|
||||||
|
|
||||||
|
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
|
||||||
|
|
||||||
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
|
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
|
||||||
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
|
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
|
||||||
|
|
||||||
@@ -108,8 +88,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers

     frontend_command = [
         "comfyui",
-        "--listen=0.0.0.0",
-        "--port=9001",
+        "--listen=127.0.0.1",
+        "--port=19001",
         "--cpu",
         "--distributed-queue-frontend",
         f"-w={str(tmp_path)}",
@@ -122,7 +102,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
     for i in range(num_workers):
         backend_command = [
             "comfyui-worker",
-            f"--port={9002 + i}",
+            f"--port={19002 + i}",
             f"-w={str(tmp_path)}",
             f"--distributed-queue-connection-uri={connection_uri}",
             f"--executor-factory={executor_factory}"
@@ -130,7 +110,7 @@
         processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))

     try:
-        server_address = f"http://{get_lan_ip()}:9001"
+        server_address = f"http://127.0.0.1:19001"
         start_time = time.time()
         connected = False
         while time.time() - start_time < 60:
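The body of the readiness loop sits outside this hunk; below is a self-contained sketch of the kind of poll it performs against the frontend address. This is an assumption for illustration, not the committed loop body.

import time

import requests


def wait_until_ready(server_address: str, budget_s: float = 60.0) -> bool:
    # Poll the frontend until it answers or the time budget runs out.
    start_time = time.time()
    while time.time() - start_time < budget_s:
        try:
            requests.get(server_address, timeout=1).raise_for_status()
            return True
        except requests.RequestException:
            time.sleep(1)
    return False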
@@ -15,7 +15,7 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex
     DependencyCycleError
 from comfy.distributed.server_stub import ServerStub
 from comfy.execution_context import context_add_custom_nodes
-from comfy.graph_utils import GraphBuilder, Node
+from comfy_execution.graph_utils import GraphBuilder, Node
 from comfy.nodes.package_typing import ExportedNodes

 current_test_name = ContextVar('current_test_name', default=None)
@@ -1,4 +1,4 @@
-from comfy.graph_utils import GraphBuilder, is_link
+from comfy_execution.graph_utils import GraphBuilder, is_link
 from comfy.graph import ExecutionBlocker
 from .tools import VariantSupport

@@ -3,7 +3,7 @@ import time
 import asyncio
 from comfy.utils import ProgressBar
 from .tools import VariantSupport
-from comfy.graph_utils import GraphBuilder
+from comfy_execution.graph_utils import GraphBuilder
 from comfy.comfy_types.node_typing import ComfyNodeABC
 from comfy.comfy_types import IO

@@ -1,4 +1,4 @@
-from comfy.graph_utils import GraphBuilder
+from comfy_execution.graph_utils import GraphBuilder
 from .tools import VariantSupport

 @VariantSupport()
tests/main_pre_import_checker.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""Pylint checker for ensuring main_pre is imported first."""

from typing import TYPE_CHECKING, Optional, Union

import astroid
from astroid import nodes
from pylint.checkers import BaseChecker

if TYPE_CHECKING:
    from pylint.lint import PyLinter


class MainPreImportOrderChecker(BaseChecker):
    """
    Ensures that imports from 'comfy.cmd.main_pre' or similar setup modules
    occur before any other relative imports or imports from the 'comfy' package.

    This is important for code that relies on setup being performed by 'main_pre'
    before other modules from the same package are imported.
    """

    name = 'main-pre-import-order'
    msgs = {
        'W0002': (
            'Setup import %s must be placed before other package imports like %s.',
            'main-pre-import-not-first',
            "To ensure necessary setup is performed, 'comfy.cmd.main_pre' or similar "
            "setup imports must precede other relative imports or imports from the "
            "'comfy' family of packages."
        ),
    }

    def _is_main_pre_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
        """Checks if an import statement is for 'comfy.cmd.main_pre'."""
        if isinstance(stmt, nodes.Import):
            for name, _ in stmt.names:
                if name == 'comfy.cmd.main_pre' or name.startswith('comfy.cmd.main_pre.'):
                    return True
            return False

        if isinstance(stmt, nodes.ImportFrom):
            qname: Optional[str] = None
            if stmt.level == 0:
                qname = stmt.modname
            else:
                try:
                    # Attempt to resolve the relative import to a fully qualified name
                    imported_module = stmt.do_import_module()
                    qname = imported_module.qname()
                except astroid.AstroidError:
                    # Fallback for unresolved relative imports, check the literal module name
                    if stmt.modname and stmt.modname.endswith('.main_pre'):
                        return True
                    # Heuristic for `from ..cmd import main_pre` in `comfy/entrypoints/*`
                    if stmt.modname == 'cmd' and stmt.root().qname().startswith('comfy'):
                        for name, _ in stmt.names:
                            if name == 'main_pre':
                                return True

            if not qname:
                return False

            # from comfy.cmd import main_pre
            if qname == 'comfy.cmd':
                for name, _ in stmt.names:
                    if name == 'main_pre':
                        return True

            # from comfy.cmd.main_pre import ... OR from a.b.c.main_pre import ...
            if qname == 'comfy.cmd.main_pre' or qname.endswith('.main_pre'):
                return True

        return False

    def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
        """
        Checks if an import is a relative import or an import from
        the 'comfy' package family, and is not a 'main_pre' import.
        """
        if self._is_main_pre_import(stmt):
            return False

        if isinstance(stmt, nodes.ImportFrom):
            if stmt.level and stmt.level > 0:  # Any relative import
                return True
            if stmt.modname and stmt.modname.startswith('comfy'):
                return True

        if isinstance(stmt, nodes.Import):
            for name, _ in stmt.names:
                if name.startswith('comfy'):
                    return True

        return False

    def visit_module(self, node: nodes.Module) -> None:
        """Checks the order of imports within a module."""
        imports = [
            stmt for stmt in node.body
            if isinstance(stmt, (nodes.Import, nodes.ImportFrom))
        ]

        main_pre_import_node = None
        for stmt in imports:
            if self._is_main_pre_import(stmt):
                main_pre_import_node = stmt
                break

        # If there's no main_pre import, there's nothing to check.
        if not main_pre_import_node:
            return

        for stmt in imports:
            # We only care about imports that appear before the main_pre import
            if stmt.fromlineno >= main_pre_import_node.fromlineno:
                break

            if self._is_other_relevant_import(stmt):
                self.add_message(
                    'main-pre-import-not-first',
                    node=main_pre_import_node,
                    args=(main_pre_import_node.as_string(), stmt.as_string())
                )
                return  # Report once per file and exit to avoid spam


def register(linter: "PyLinter") -> None:
    """This function is required for a Pylint plugin."""
    linter.register_checker(MainPreImportOrderChecker(linter))
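A hedged illustration of the ordering this checker enforces once it is registered as a Pylint plugin via the register() hook above. The module names are examples drawn from elsewhere in this diff; the second block is shown commented out because, written as real imports, it would trigger the warning.

# Compliant: the setup module is imported before any other comfy import.
from comfy.cmd import main_pre  # noqa: F401  # side-effect setup runs first

from comfy.cli_args import args
from comfy.cmd.execution import validate_inputs

# Non-compliant (would be reported as main-pre-import-not-first / W0002):
# from comfy.cli_args import args
# from comfy.cmd import main_pre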
tests/unit/prompt_server_test/user_manager_test.py (new file, 292 lines)
@@ -0,0 +1,292 @@
import os
from unittest.mock import patch

import pytest
from aiohttp import web

from comfy.app.user_manager import UserManager

pytestmark = (
    pytest.mark.asyncio
)  # This applies the asyncio mark to all test functions in the module


@pytest.fixture
def user_manager(tmp_path):
    um = UserManager()
    um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
        tmp_path, file
    ) if file else tmp_path
    return um


@pytest.fixture
def app(user_manager):
    app = web.Application()
    routes = web.RouteTableDef()
    user_manager.add_routes(routes)
    app.add_routes(routes)
    return app


async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir")
    assert resp.status == 404


async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir")
    with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir")
    assert resp.status == 200
    assert await resp.json() == ["file1.txt"]


async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir" / "subdir")
    with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
        f.write("test content")
    with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&recurse=true")
    assert resp.status == 200
    assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}


async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir")
    with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&full_info=true")
    assert resp.status == 200
    result = await resp.json()
    assert len(result) == 1
    assert result[0]["path"] == "file1.txt"
    assert "size" in result[0]
    assert "modified" in result[0]


async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir" / "subdir")
    with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
    assert resp.status == 200
    assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]


async def test_listuserdata_invalid_directory(aiohttp_client, app):
    client = await aiohttp_client(app)
    resp = await client.get("/userdata?dir=")
    assert resp.status == 400


async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
    os_sep = "\\"
    with patch("os.sep", os_sep):
        with patch("os.path.sep", os_sep):
            os.makedirs(tmp_path / "test_dir" / "subdir")
            with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
                f.write("test content")

            client = await aiohttp_client(app)
            resp = await client.get("/userdata?dir=test_dir&recurse=true")
            assert resp.status == 200
            result = await resp.json()
            assert len(result) == 1
            assert "/" in result[0]  # Ensure forward slash is used
            assert "\\" not in result[0]  # Ensure backslash is not present
            assert result[0] == "subdir/file1.txt"

            # Test with full_info
            resp = await client.get(
                "/userdata?dir=test_dir&recurse=true&full_info=true"
            )
            assert resp.status == 200
            result = await resp.json()
            assert len(result) == 1
            assert "/" in result[0]["path"]  # Ensure forward slash is used
            assert "\\" not in result[0]["path"]  # Ensure backslash is not present
            assert result[0]["path"] == "subdir/file1.txt"


async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    content = b"test content"
    resp = await client.post("/userdata/test.txt", data=content)

    assert resp.status == 200
    assert await resp.text() == '"test.txt"'

    # Verify file was created with correct content
    with open(tmp_path / "test.txt", "rb") as f:
        assert f.read() == content


async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "test.txt", "w") as f:
        f.write("initial content")

    client = await aiohttp_client(app)
    new_content = b"updated content"
    resp = await client.post("/userdata/test.txt", data=new_content)

    assert resp.status == 200
    assert await resp.text() == '"test.txt"'

    # Verify file was overwritten
    with open(tmp_path / "test.txt", "rb") as f:
        assert f.read() == new_content


async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "test.txt", "w") as f:
        f.write("initial content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")

    assert resp.status == 409

    # Verify original content unchanged
    with open(tmp_path / "test.txt", "r") as f:
        assert f.read() == "initial content"


async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
    client = await aiohttp_client(app)
    content = b"test content"
    resp = await client.post("/userdata/test.txt?full_info=true", data=content)

    assert resp.status == 200
    result = await resp.json()
    assert result["path"] == "test.txt"
    assert result["size"] == len(content)
    assert "modified" in result


async def test_move_userdata(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "source.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/source.txt/move/dest.txt")

    assert resp.status == 200
    assert await resp.text() == '"dest.txt"'

    # Verify file was moved
    assert not os.path.exists(tmp_path / "source.txt")
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "test content"


async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
    # Create source and destination files
    with open(tmp_path / "source.txt", "w") as f:
        f.write("source content")
    with open(tmp_path / "dest.txt", "w") as f:
        f.write("destination content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")

    assert resp.status == 409

    # Verify files remain unchanged
    with open(tmp_path / "source.txt", "r") as f:
        assert f.read() == "source content"
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "destination content"


async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
    # Create initial file
    with open(tmp_path / "source.txt", "w") as f:
        f.write("test content")

    client = await aiohttp_client(app)
    resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")

    assert resp.status == 200
    result = await resp.json()
    assert result["path"] == "dest.txt"
    assert result["size"] == len("test content")
    assert "modified" in result

    # Verify file was moved
    assert not os.path.exists(tmp_path / "source.txt")
    with open(tmp_path / "dest.txt", "r") as f:
        assert f.read() == "test content"


async def test_listuserdata_v2_empty_root(aiohttp_client, app):
    client = await aiohttp_client(app)
    resp = await client.get("/v2/userdata")
    assert resp.status == 200
    assert await resp.json() == []


async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app):
    client = await aiohttp_client(app)
    resp = await client.get("/v2/userdata?path=does_not_exist")
    assert resp.status == 404


async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path):
    os.makedirs(tmp_path / "test_dir" / "subdir")
    (tmp_path / "test_dir" / "file1.txt").write_text("content")
    (tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content")

    client = await aiohttp_client(app)
    resp = await client.get("/v2/userdata?path=test_dir")
    assert resp.status == 200
    data = await resp.json()
    file_paths = {item["path"] for item in data if item["type"] == "file"}
    assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"}


async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch):
    # Force backslash as os separator
    monkeypatch.setattr(os, 'sep', '\\')
    monkeypatch.setattr(os.path, 'sep', '\\')
    os.makedirs(tmp_path / "test_dir" / "subdir")
    (tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x")

    client = await aiohttp_client(app)
    resp = await client.get("/v2/userdata?path=test_dir")
    assert resp.status == 200
    data = await resp.json()
    for item in data:
        assert "/" in item["path"]
        assert "\\" not in item["path"]


async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path):
    # Create a directory with a space in its name and a file inside
    os.makedirs(tmp_path / "my dir")
    (tmp_path / "my dir" / "file.txt").write_text("content")

    client = await aiohttp_client(app)
    # Use URL-encoded space in path parameter
    resp = await client.get("/v2/userdata?path=my%20dir&recurse=false")
    assert resp.status == 200
    data = await resp.json()
    assert len(data) == 1
    entry = data[0]
    assert entry["name"] == "file.txt"
    # Ensure the path is correctly decoded and uses forward slash
    assert entry["path"] == "my dir/file.txt"
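As a usage illustration only: the /v2/userdata listing exercised above can be queried against a running server the same way. The base URL and the queried path below are assumptions, while the entry fields (type, name, path) follow the assertions in these tests.

import requests

# Hedged example; 127.0.0.1:8188 and the "workflows" path are assumed values.
resp = requests.get("http://127.0.0.1:8188/v2/userdata", params={"path": "workflows"})
resp.raise_for_status()
for entry in resp.json():
    print(entry["type"], entry["name"], entry["path"])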
@@ -122,7 +122,7 @@ def test_bool_request_parameter():
     assert v == True


-def test_string_enum_request_parameter():
+async def test_string_enum_request_parameter():
     nt = StringEnumRequestParameter.INPUT_TYPES()
     assert nt is not None
     n = StringEnumRequestParameter()
@@ -155,8 +155,9 @@ def test_string_enum_request_parameter():
     }
     from comfy.cmd.execution import validate_inputs
     validated: dict[str, ValidateInputsTuple] = {}
-    validated["1"] = validate_inputs(prompt, "1", validated)
-    validated["2"] = validate_inputs(prompt, "2", validated)
+    prompt_id = str(uuid.uuid4())
+    validated["1"] = await validate_inputs(prompt_id, prompt, "1", validated)
+    validated["2"] = await validate_inputs(prompt_id, prompt, "2", validated)
     assert validated["2"].valid

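For reference, a minimal sketch of driving the now-async validate_inputs outside a pytest-asyncio test, following the signature used in the hunk above; the prompt contents are a placeholder.

import asyncio
import uuid

from comfy.cmd.execution import validate_inputs


async def validate_node(prompt: dict, node_id: str):
    # validate_inputs is awaited and takes a prompt id as its first argument.
    validated = {}
    validated[node_id] = await validate_inputs(str(uuid.uuid4()), prompt, node_id, validated)
    return validated[node_id]

# asyncio.run(validate_node({...}, "1"))  # placeholder prompt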