packaging fixes

- enable user db
 - fix main_pre order everywhere
 - fix absolute to relative imports everywhere
 - async better supported
This commit is contained in:
doctorpangloss 2025-07-15 10:19:33 -07:00
parent c086c5e005
commit 96b4e04315
39 changed files with 604 additions and 151 deletions

View File

@ -82,7 +82,7 @@ limit-inference-results=100
# List of plugins (as comma separated values of python module names) to load,
# usually to register additional checkers.
load-plugins=tests.absolute_import_checker
load-plugins=tests.absolute_import_checker,tests.main_pre_import_checker
# Pickle collected data for later comparisons.
persistent=yes
@ -678,7 +678,7 @@ disable=raw-checker-failed,
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=absolute-import-used
enable=
[METHOD_ARGS]

View File

@ -13,7 +13,7 @@ script_location = alembic_db
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
@ -63,7 +63,7 @@ version_path_separator = os
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///user/comfyui.db
sqlalchemy.url = sqlite:///./comfyui.db
[post_write_hooks]

View File

View File

@ -1,3 +1,4 @@
# pylint: disable=no-member
from sqlalchemy import engine_from_config
from sqlalchemy import pool
@ -7,8 +8,7 @@ from alembic import context
# access to the values within the .ini file in use.
config = context.config
from comfy.app.database.models import Base
from ..app.database.models import Base
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,

View File

@ -1,8 +1,10 @@
import logging
import os
import shutil
from importlib.resources import files
from ...cli_args import args
from ...component_model.files import get_package_as_path
Session = None
@ -15,6 +17,7 @@ from sqlalchemy.orm import sessionmaker
_DB_AVAILABLE = True
logger = logging.getLogger(__name__)
def dependencies_available():
"""
@ -32,9 +35,8 @@ def can_create_session():
def get_alembic_config():
root_path = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
config_path = str(files("comfy") / "alembic.ini")
scripts_path = get_package_as_path("comfy.alembic_db")
config = Config(config_path)
config.set_main_option("script_location", scripts_path)
@ -53,7 +55,7 @@ def get_db_path():
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
logger.debug(f"Database URL: {db_url}")
db_path = get_db_path()
db_exists = os.path.exists(db_path)
@ -70,7 +72,7 @@ def init_db():
target_rev = script.get_current_head()
if target_rev is None:
logging.warning("No target revision found.")
logger.debug("No target revision found.")
elif current_rev != target_rev:
# Backup the database pre upgrade
backup_path = db_path + ".bkp"
@ -81,13 +83,13 @@ def init_db():
try:
command.upgrade(config, target_rev)
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
logger.info(f"Database upgraded from {current_rev} to {target_rev}")
except Exception as e:
if backup_path:
# Restore the database from backup if upgrade fails
shutil.copy(backup_path, db_path)
os.remove(backup_path)
logging.exception("Error upgrading database: ")
logger.exception("Error upgrading database: ")
raise e
global Session

View File

@ -5,7 +5,8 @@ import torch
import torch.nn as nn
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
from ..ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, \
get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):

View File

@ -11,7 +11,7 @@ import configargparse as argparse
from . import __version__
from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, EnumAction, \
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory
EnhancedConfigArgParser, PerformanceFeature, is_valid_directory, db_config
# todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
@ -261,8 +261,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///:memory:", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
default_db_url = db_config()
parser.add_argument("--database-url", type=str, default=default_db_url, help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--workflows", type=str, nargs='+', default=[], help="Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.")
# now give plugins a chance to add configuration

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import enum
import logging
import os
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
@ -20,6 +21,22 @@ class LatentPreviewMethod(enum.Enum):
ConfigObserver = Callable[[str, Any], None]
def db_config() -> str:
from .vendor.appdirs import user_data_dir
logger = logging.getLogger(__name__)
try:
data_dir = user_data_dir(appname="comfyui")
os.makedirs(data_dir, exist_ok=True)
db_path = os.path.join(data_dir, "comfy.db")
default_db_url = f"sqlite:///{db_path}"
except Exception as e:
# Fallback to an in-memory database if the user directory can't be accessed
logger.warning(f"Could not determine user data directory for database, falling back to in-memory: {e}")
default_db_url = "sqlite:///:memory:"
return default_db_url
def is_valid_directory(path: str) -> str:
"""Validate if the given path is a directory, and check permissions."""
if not os.path.exists(path):
@ -261,7 +278,7 @@ class Configuration(dict):
self.front_end_version: str = "comfyanonymous/ComfyUI@latest"
self.front_end_root: Optional[str] = None
self.comfy_api_base: str = "https://api.comfy.org"
self.database_url: str = "sqlite:///:memory:"
self.database_url: str = db_config()
for key, value in kwargs.items():
self[key] = value

View File

@ -14,11 +14,11 @@ from opentelemetry import context, propagate
from opentelemetry.context import Context, attach, detach
from opentelemetry.trace import Status, StatusCode
from ..cmd.main_pre import tracer
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable
from ..distributed.executors import ContextVarExecutor
@ -97,7 +97,7 @@ async def __execute_prompt(
else:
prompt_executor.server = progress_handler
prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
await prompt_executor.execute_async(prompt_mut, prompt_id, {"client_id": client_id},
execute_outputs=validation_tuple.good_output_node_ids)
return prompt_executor.outputs_ui
except Exception as exc_info:
@ -195,7 +195,7 @@ class Comfy:
if isinstance(prompt, str):
prompt = json.loads(prompt)
if isinstance(prompt, dict):
from comfy.api.components.schema.prompt import Prompt
from ..api.components.schema.prompt import Prompt
prompt = Prompt.validate(prompt)
outputs = await self.queue_prompt(prompt)
return V1QueuePromptResponse(urls=[], outputs=outputs)

View File

@ -19,14 +19,15 @@ from typing import List, Optional, Tuple, Literal
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
# order matters
from .main_pre import tracer
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
DependencyAwareCache, \
BasicCache
# order matters
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..cli_args import args
@ -37,7 +38,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path
from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@ -388,7 +389,7 @@ def format_value(x) -> FormattedValue:
return str(x.__class__)
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, _node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, node_id: str, extra_data: dict, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
"""
:param server:
@ -402,8 +403,8 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
:param pending_subgraph_results:
:return:
"""
with context_execute_node(_node_id):
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
with context_execute_node(node_id):
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
@ -516,7 +517,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
unblock()
asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None)
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
"meta": {
@ -685,11 +686,6 @@ class PromptExecutor:
if ex is not None and self.raise_exceptions:
raise ex
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
# torchao and potentially other optimization approaches break when the models are created in inference mode
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
@ -1206,7 +1202,7 @@ class PromptQueue(AbstractPromptQueue):
self.server.queue_updated()
return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id: str, outputs: dict,
def task_done(self, item_id: str, outputs: HistoryResultDict,
status: Optional[ExecutionStatus]):
history_result = outputs
with self.mutex:
@ -1215,9 +1211,9 @@ class PromptQueue(AbstractPromptQueue):
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))
status_dict: Optional[dict] = None
status_dict = None
if status is not None:
status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict()
outputs_ = history_result["outputs"]
# Remove sensitive data from extra_data before storing in history
@ -1225,11 +1221,13 @@ class PromptQueue(AbstractPromptQueue):
if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val)
self.history[prompt[1]] = {
history_entry: HistoryEntry = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs_),
'status': status_dict,
}
if status_dict is not None:
history_entry["status"] = status_dict
self.history[prompt[1]] = history_entry
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
if queue_item.completed:

View File

@ -10,20 +10,22 @@ import time
from pathlib import Path
from typing import Optional
from comfy.component_model.entrypoints_common import configure_application_paths, executor_from_args
# main_pre must be the earliest import
from .main_pre import args
from . import hook_breaker_ac10a0
from .extra_model_paths import load_extra_path_config
# main_pre must be the earliest import since it suppresses some spurious warnings
from .main_pre import args
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
from ..cmd import cuda_malloc
from ..cmd import folder_paths
from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..distributed.server_stub import ServerStub
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes_context import get_nodes
logger = logging.getLogger(__name__)
@ -42,6 +44,10 @@ def cuda_malloc_warning():
def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
asyncio.run(_prompt_worker(q, server_instance))
async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
from ..cmd import execution
from ..component_model import queue_types
from .. import model_management
@ -68,7 +74,7 @@ def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptS
prompt_id = item[1]
server_instance.last_prompt_id = prompt_id
e.execute(item[2], prompt_id, item[3], item[4])
await e.execute_async(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id,
e.history_result,
@ -174,17 +180,16 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
for config_path in itertools.chain(*args.extra_model_paths_config):
load_extra_path_config(config_path)
# always create directories when started interactively
folder_paths.create_directories()
if args.create_directories:
# then, import and exit
import_all_nodes_in_workspace(raise_on_failure=False)
folder_paths.create_directories()
exit(0)
setup_database()
elif args.quick_test_for_ci:
import_all_nodes_in_workspace(raise_on_failure=True)
exit(0)
if args.windows_standalone_build:
folder_paths.create_directories()
try:
from . import new_updater
new_updater.update_windows_updater()
@ -198,7 +203,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
# at this stage, it's safe to import nodes
hook_breaker_ac10a0.save_functions()
server.nodes = import_all_nodes_in_workspace()
server.nodes = get_nodes()
hook_breaker_ac10a0.restore_functions()
# as a side effect, this also populates the nodes for execution
@ -221,6 +226,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
server.add_routes()
cuda_malloc_warning()
setup_database()
# in a distributed setting, the default prompt worker will not be able to send execution events via the websocket
worker_thread_server = server if not distributed else ServerStub()
@ -254,14 +260,7 @@ async def _start_comfyui(from_script_dir: Optional[Path] = None):
logger.debug(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci:
# for CI purposes, try importing all the nodes
import_all_nodes_in_workspace(raise_on_failure=True)
return
else:
# we no longer lazily load nodes. we'll do it now for the sake of creating directories
import_all_nodes_in_workspace(raise_on_failure=False)
# now that nodes are loaded, create more directories if appropriate
# now that nodes are loaded, create directories
folder_paths.create_directories()
if len(args.workflows) > 0:

View File

@ -1169,7 +1169,7 @@ class PromptServer(ExecutorToClientProgress):
await runner.setup()
if 'tls_keyfile' in args or 'tls_certfile' in args:
raise ValueError("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
logger.warning("Use caddy instead of aiohttp to serve https by setting up a reverse proxy. See README.md")
def is_ipv4(address: str, *args):
try:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import typing
from abc import ABCMeta, abstractmethod
from .executor_types import HistoryResultDict
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems
@ -42,7 +43,7 @@ class AbstractPromptQueue(metaclass=ABCMeta):
pass
@abstractmethod
def task_done(self, item_id: str, outputs: dict,
def task_done(self, item_id: str, outputs: HistoryResultDict,
status: typing.Optional[ExecutionStatus]):
"""
Signals to the user interface that the task with the specified id is completed

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import copy
from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple
@ -24,6 +25,13 @@ class ExecutionStatus(NamedTuple):
completed: bool
messages: List[str]
def as_dict(self) -> ExecutionStatusAsDict:
return {
"status_str": self.status_str,
"completed": self.completed,
"messages": copy.copy(self.messages),
}
class ExecutionError(RuntimeError):
def __init__(self, task_id: int | str, status: Optional[ExecutionStatus] = None, exceptions: Optional[Sequence[Exception]] = None, *args):

View File

@ -13,12 +13,12 @@ from aio_pika import connect_robust
from aio_pika.abc import AbstractConnection, AbstractChannel
from aio_pika.patterns import JsonRPC
from ..cmd.main_pre import tracer
from .distributed_progress import ProgressHandlers
from .distributed_types import RpcRequest, RpcReply
from .history import History
from .server_stub import ServerStub
from ..auth.permissions import jwt_decode
from ..cmd.main_pre import tracer
from ..cmd.server import PromptServer
from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict

View File

@ -10,12 +10,12 @@ from aio_pika.patterns import JsonRPC
from aiohttp import web
from aiormq import AMQPConnectionError
from ..cmd.main_pre import tracer
from .executors import ContextVarExecutor
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from .process_pool_executor import ProcessPoolExecutor
from ..client.embedded_comfy_client import Comfy
from ..cmd.main_pre import tracer
from ..component_model.queue_types import ExecutionStatus
logger = logging.getLogger(__name__)

View File

@ -2,7 +2,6 @@ import asyncio
from ..cmd.main_pre import args
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
from ..distributed.executors import ContextVarExecutor, ContextVarProcessPoolExecutor
async def main():

View File

@ -1,7 +1,8 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from ..ldm.modules.attention import optimized_attention_for_device
from ..model_management import cast_to_device
from ..text_encoders.bert import BertAttention
class Dino2AttentionOutput(torch.nn.Module):
@ -29,7 +30,7 @@ class LayerScale(torch.nn.Module):
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
return x * cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
@ -117,7 +118,7 @@ class Dino2Embeddings(torch.nn.Module):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
x = x + cast_to_device(self.position_embeddings, x.device, x.dtype)
return x

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import logging
from importlib.abc import Traversable # pylint: disable=no-name-in-module
from importlib.resources import files
from pathlib import Path
@ -10,7 +9,7 @@ KNOWN_CHAT_TEMPLATES = {}
def _update_known_chat_templates():
try:
_chat_templates: Traversable = files(__package__) / "chat_templates"
_chat_templates = files(__package__) / "chat_templates"
_extra_jinja_templates = {Path(traversable.name).stem: traversable.read_text().replace(' ', '').replace('\n', '') for traversable in _chat_templates.iterdir() if traversable.is_file()}
KNOWN_CHAT_TEMPLATES.update(_extra_jinja_templates)
except ImportError as exc:

View File

@ -10,7 +10,7 @@ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput, T
from transformers.utils import PaddingStrategy
from typing_extensions import TypedDict, NotRequired
from comfy.component_model.tensor_types import RGBImageBatch
from ..component_model.tensor_types import RGBImageBatch
class ProcessorResult(TypedDict):

View File

@ -23,7 +23,7 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from comfy.ldm.modules.attention import optimized_attention
from ..modules.attention import optimized_attention
def get_normalization(name: str, channels: int, weight_args={}, operations=None):

View File

@ -11,7 +11,7 @@ import math
from .position_embedding import VideoRopePosition3DEmb, LearnablePosEmbAxis
from torchvision import transforms
from comfy.ldm.modules.attention import optimized_attention
from ..modules.attention import optimized_attention
def apply_rotary_pos_emb(
t: torch.Tensor,

View File

@ -1,12 +1,6 @@
import torch
from torch import nn
from comfy.ldm.flux.layers import (
DoubleStreamBlock,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from ..flux.layers import DoubleStreamBlock, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding
class Hunyuan3Dv2(nn.Module):

View File

@ -5,10 +5,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from comfy.ldm.modules.diffusionmodules.model import vae_attention
from ..modules.diffusionmodules.model import vae_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
from ...ops import disable_weight_init
ops = disable_weight_init
CACHE_T = 2

View File

@ -31,9 +31,9 @@ import psutil
import torch
from opentelemetry.trace import get_current_span
from .cmd.main_pre import tracer
from . import interruption
from .cli_args import args, PerformanceFeature
from .cmd.main_pre import tracer
from .component_model.deprecation import _deprecate_method
from .model_management_types import ModelManageable

View File

@ -11,12 +11,13 @@ from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode
from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer
from .package_typing import ExportedNodes
from ..component_model.files import get_package_as_path
_comfy_nodes: ExportedNodes = ExportedNodes()
_nodes_available_at_startup: ExportedNodes = ExportedNodes()
logger = logging.getLogger(__name__)
def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType):
node_class_mappings = getattr(module, 'NODE_CLASS_MAPPINGS', None)
@ -55,7 +56,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
timings.append((time.perf_counter() - time_before, full_name, True, exported_nodes))
except Exception as exc:
module_decl = None
logging.error(f"{full_name} import failed", exc_info=exc)
logger.error(f"{full_name} import failed", exc_info=exc)
span.set_status(Status(StatusCode.ERROR))
span.record_exception(exc)
exceptions.append(exc)
@ -84,7 +85,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
potential_path_error: AttributeError = x
if potential_path_error.name == '__path__':
continue
logging.error(f"{full_name} import failed", exc_info=x)
logger.error(f"{full_name} import failed", exc_info=x)
success = False
exceptions.append(x)
span.set_status(Status(StatusCode.ERROR))
@ -93,7 +94,7 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
if print_import_times and len(timings) > 0 or any(not success for (_, _, success, _) in timings):
for (duration, module_name, success, new_nodes) in sorted(timings):
logging.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
logger.log(logging.DEBUG if success else logging.ERROR, f"{duration:6.1f} seconds{'' if success else ' (IMPORT FAILED)'}, {module_name} ({len(new_nodes)} nodes loaded)")
if raise_on_failure and len(exceptions) > 0:
try:
raise ExceptionGroup("Node import failed", exceptions) # pylint: disable=using-exception-groups-in-unsupported-version
@ -105,12 +106,16 @@ def _import_and_enumerate_nodes_in_module(module: types.ModuleType,
@tracer.start_as_current_span("Import All Nodes In Workspace")
def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
# now actually import the nodes, to improve control of node loading order
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from ..cli_args import args
from . import base_nodes
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
# only load these nodes once
if len(_comfy_nodes) == 0:
if len(_nodes_available_at_startup) == 0:
# import base_nodes first
from . import base_nodes
from comfy_extras import nodes as comfy_extras_nodes # pylint: disable=absolute-import-used
from .vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
base_and_extra = reduce(lambda x, y: x.update(y),
map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
# this is the list of default nodes to import
@ -121,9 +126,9 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
custom_nodes_mappings = ExportedNodes()
if args.disable_all_custom_nodes:
logging.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
_comfy_nodes.update(base_and_extra)
return _comfy_nodes
logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
_nodes_available_at_startup.update(base_and_extra)
return _nodes_available_at_startup
# load from entrypoints
for entry_point in entry_points().select(group='comfyui.custom_nodes'):
@ -131,7 +136,7 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
try:
module = entry_point.load()
except ModuleNotFoundError as module_not_found_error:
logging.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
continue
# Ensure that what we've loaded is indeed a module
@ -146,5 +151,5 @@ def import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=Fa
# don't allow custom nodes to overwrite base nodes
custom_nodes_mappings -= base_and_extra
_comfy_nodes.update(base_and_extra + custom_nodes_mappings)
return _comfy_nodes
_nodes_available_at_startup.update(base_and_extra + custom_nodes_mappings)
return _nodes_available_at_startup

View File

@ -7,8 +7,6 @@ from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, T
from typing_extensions import TypedDict, NotRequired
from comfy.comfy_types import FileLocator
T = TypeVar('T')

View File

@ -131,6 +131,7 @@ dev = [
"freezegun",
"coverage",
"pylint",
"astroid",
]
[project.optional-dependencies]

View File

@ -24,18 +24,46 @@ class AbsoluteImportChecker(BaseChecker):
super().__init__(linter)
def visit_importfrom(self, node: nodes.ImportFrom) -> None:
current_file = node.root().file
if current_file is None:
"""
Check for absolute imports from the same top-level package.
This method is called for every `from ... import ...` statement.
It checks if a module within 'comfy' or 'comfy_extras' packages
is using an absolute import from its own package, which should
be a relative import instead.
For example, inside `comfy/nodes/logic.py`, an import like
`from comfy.utils import some_function` will be flagged.
The preferred way would be `from ..utils import some_function`.
"""
# An import is relative if its level is greater than 0.
# e.g., from . import foo (level=1), from .. import bar (level=2)
# We only want to check absolute imports, so we skip relative ones.
if node.level and node.level > 0:
return
package_path = os.path.dirname(current_file)
package_name = os.path.basename(package_path)
# Get the fully qualified name of the module being linted.
# For a file at '.../comfy/nodes/common.py', this will be 'comfy.nodes.common'.
module_qname = node.root().qname()
if node.modname.startswith(package_name) and package_name in ['comfy', 'comfy_extras']:
import_parts = node.modname.split('.')
# `node.modname` is the module name in the `from` statement.
# For `from comfy.utils import x`, `modname` is `comfy.utils`.
imported_modname = node.modname
if not imported_modname:
return
if import_parts[0] == package_name:
self.add_message('absolute-import-used', node=node, args=(node.modname,))
# We are only interested in modules within 'comfy' or 'comfy_extras'.
# We determine this by looking at the first part of the qualified name.
current_top_package = module_qname.split('.')[0]
if current_top_package not in ['comfy', 'comfy_extras']:
return
imported_top_package = imported_modname.split('.')[0]
# If the top-level package of the imported module is the same as the
# current module's top-level package, it's an internal absolute import.
if imported_top_package == current_top_package:
self.add_message('absolute-import-used', node=node, args=(imported_modname,))
def register(linter: "PyLinter") -> None:

View File

@ -1,50 +1,25 @@
import sys
import time
import logging
import multiprocessing
import os
import pathlib
import subprocess
import sys
import time
import urllib
from typing import Tuple, List
import pytest
import requests
import socket
import subprocess
import urllib
from testcontainers.rabbitmq import RabbitMqContainer
from typing import Tuple, List
from comfy.cli_args_types import Configuration
logging.getLogger("pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("aio_pika").setLevel(logging.CRITICAL + 1)
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
# fixes issues with running the testcontainers rabbitmqcontainer on Windows
os.environ["TC_HOST"] = "localhost"
def get_lan_ip():
"""
Finds the host's IP address on the LAN it's connected to.
Returns:
str: The IP address of the host on the LAN.
"""
# Create a dummy socket
s = None
try:
# Connect to a dummy address (Here, Google's public DNS server)
# The actual connection is not made, but this allows finding out the LAN IP
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
finally:
if s is not None:
s.close()
return ip
def run_server(server_arguments: Configuration):
from comfy.cmd.main import main
from comfy.cli_args import args
@ -95,6 +70,11 @@ def has_gpu() -> bool:
@pytest.fixture(scope="module", autouse=False, params=["ThreadPoolExecutor", "ProcessPoolExecutor"])
def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers: int = 1):
from huggingface_hub import hf_hub_download
from testcontainers.rabbitmq import RabbitMqContainer
logging.getLogger("testcontainers.core.container").setLevel(logging.WARNING)
logging.getLogger("testcontainers.core.waiting_utils").setLevel(logging.WARNING)
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
@ -108,8 +88,8 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
frontend_command = [
"comfyui",
"--listen=0.0.0.0",
"--port=9001",
"--listen=127.0.0.1",
"--port=19001",
"--cpu",
"--distributed-queue-frontend",
f"-w={str(tmp_path)}",
@ -122,7 +102,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
for i in range(num_workers):
backend_command = [
"comfyui-worker",
f"--port={9002 + i}",
f"--port={19002 + i}",
f"-w={str(tmp_path)}",
f"--distributed-queue-connection-uri={connection_uri}",
f"--executor-factory={executor_factory}"
@ -130,7 +110,7 @@ def frontend_backend_worker_with_rabbitmq(request, tmp_path_factory, num_workers
processes_to_close.append(subprocess.Popen(backend_command, stdout=sys.stdout, stderr=sys.stderr))
try:
server_address = f"http://{get_lan_ip()}:9001"
server_address = f"http://127.0.0.1:19001"
start_time = time.time()
connected = False
while time.time() - start_time < 60:

View File

@ -15,7 +15,7 @@ from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, Ex
DependencyCycleError
from comfy.distributed.server_stub import ServerStub
from comfy.execution_context import context_add_custom_nodes
from comfy.graph_utils import GraphBuilder, Node
from comfy_execution.graph_utils import GraphBuilder, Node
from comfy.nodes.package_typing import ExportedNodes
current_test_name = ContextVar('current_test_name', default=None)

View File

@ -1,4 +1,4 @@
from comfy.graph_utils import GraphBuilder, is_link
from comfy_execution.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker
from .tools import VariantSupport

View File

@ -3,7 +3,7 @@ import time
import asyncio
from comfy.utils import ProgressBar
from .tools import VariantSupport
from comfy.graph_utils import GraphBuilder
from comfy_execution.graph_utils import GraphBuilder
from comfy.comfy_types.node_typing import ComfyNodeABC
from comfy.comfy_types import IO

View File

@ -1,4 +1,4 @@
from comfy.graph_utils import GraphBuilder
from comfy_execution.graph_utils import GraphBuilder
from .tools import VariantSupport
@VariantSupport()

View File

@ -0,0 +1,129 @@
"""Pylint checker for ensuring main_pre is imported first."""
from typing import TYPE_CHECKING, Optional, Union
import astroid
from astroid import nodes
from pylint.checkers import BaseChecker
if TYPE_CHECKING:
from pylint.lint import PyLinter
class MainPreImportOrderChecker(BaseChecker):
"""
Ensures that imports from 'comfy.cmd.main_pre' or similar setup modules
occur before any other relative imports or imports from the 'comfy' package.
This is important for code that relies on setup being performed by 'main_pre'
before other modules from the same package are imported.
"""
name = 'main-pre-import-order'
msgs = {
'W0002': (
'Setup import %s must be placed before other package imports like %s.',
'main-pre-import-not-first',
"To ensure necessary setup is performed, 'comfy.cmd.main_pre' or similar "
"setup imports must precede other relative imports or imports from the "
"'comfy' family of packages."
),
}
def _is_main_pre_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
"""Checks if an import statement is for 'comfy.cmd.main_pre'."""
if isinstance(stmt, nodes.Import):
for name, _ in stmt.names:
if name == 'comfy.cmd.main_pre' or name.startswith('comfy.cmd.main_pre.'):
return True
return False
if isinstance(stmt, nodes.ImportFrom):
qname: Optional[str] = None
if stmt.level == 0:
qname = stmt.modname
else:
try:
# Attempt to resolve the relative import to a fully qualified name
imported_module = stmt.do_import_module()
qname = imported_module.qname()
except astroid.AstroidError:
# Fallback for unresolved relative imports, check the literal module name
if stmt.modname and stmt.modname.endswith('.main_pre'):
return True
# Heuristic for `from ..cmd import main_pre` in `comfy/entrypoints/*`
if stmt.modname == 'cmd' and stmt.root().qname().startswith('comfy'):
for name, _ in stmt.names:
if name == 'main_pre':
return True
if not qname:
return False
# from comfy.cmd import main_pre
if qname == 'comfy.cmd':
for name, _ in stmt.names:
if name == 'main_pre':
return True
# from comfy.cmd.main_pre import ... OR from a.b.c.main_pre import ...
if qname == 'comfy.cmd.main_pre' or qname.endswith('.main_pre'):
return True
return False
def _is_other_relevant_import(self, stmt: Union[nodes.Import, nodes.ImportFrom]) -> bool:
"""
Checks if an import is a relative import or an import from
the 'comfy' package family, and is not a 'main_pre' import.
"""
if self._is_main_pre_import(stmt):
return False
if isinstance(stmt, nodes.ImportFrom):
if stmt.level and stmt.level > 0: # Any relative import
return True
if stmt.modname and stmt.modname.startswith('comfy'):
return True
if isinstance(stmt, nodes.Import):
for name, _ in stmt.names:
if name.startswith('comfy'):
return True
return False
def visit_module(self, node: nodes.Module) -> None:
"""Checks the order of imports within a module."""
imports = [
stmt for stmt in node.body
if isinstance(stmt, (nodes.Import, nodes.ImportFrom))
]
main_pre_import_node = None
for stmt in imports:
if self._is_main_pre_import(stmt):
main_pre_import_node = stmt
break
# If there's no main_pre import, there's nothing to check.
if not main_pre_import_node:
return
for stmt in imports:
# We only care about imports that appear before the main_pre import
if stmt.fromlineno >= main_pre_import_node.fromlineno:
break
if self._is_other_relevant_import(stmt):
self.add_message(
'main-pre-import-not-first',
node=main_pre_import_node,
args=(main_pre_import_node.as_string(), stmt.as_string())
)
return # Report once per file and exit to avoid spam
def register(linter: "PyLinter") -> None:
"""This function is required for a Pylint plugin."""
linter.register_checker(MainPreImportOrderChecker(linter))

View File

@ -0,0 +1,292 @@
import os
from unittest.mock import patch
import pytest
from aiohttp import web
from comfy.app.user_manager import UserManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def user_manager(tmp_path):
um = UserManager()
um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join(
tmp_path, file
) if file else tmp_path
return um
@pytest.fixture
def app(user_manager):
app = web.Application()
routes = web.RouteTableDef()
user_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 404
async def test_listuserdata_with_files(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir")
assert resp.status == 200
assert await resp.json() == ["file1.txt"]
async def test_listuserdata_recursive(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
assert set(await resp.json()) == {"file1.txt", "subdir/file2.txt"}
async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir")
with open(tmp_path / "test_dir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&full_info=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert result[0]["path"] == "file1.txt"
assert "size" in result[0]
assert "modified" in result[0]
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true")
assert resp.status == 200
assert await resp.json() == [["subdir/file1.txt", "subdir", "file1.txt"]]
async def test_listuserdata_invalid_directory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=")
assert resp.status == 400
async def test_listuserdata_normalized_separator(aiohttp_client, app, tmp_path):
os_sep = "\\"
with patch("os.sep", os_sep):
with patch("os.path.sep", os_sep):
os.makedirs(tmp_path / "test_dir" / "subdir")
with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.get("/userdata?dir=test_dir&recurse=true")
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0] # Ensure forward slash is used
assert "\\" not in result[0] # Ensure backslash is not present
assert result[0] == "subdir/file1.txt"
# Test with full_info
resp = await client.get(
"/userdata?dir=test_dir&recurse=true&full_info=true"
)
assert resp.status == 200
result = await resp.json()
assert len(result) == 1
assert "/" in result[0]["path"] # Ensure forward slash is used
assert "\\" not in result[0]["path"] # Ensure backslash is not present
assert result[0]["path"] == "subdir/file1.txt"
async def test_post_userdata_new_file(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt", data=content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was created with correct content
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == content
async def test_post_userdata_overwrite_existing(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
new_content = b"updated content"
resp = await client.post("/userdata/test.txt", data=new_content)
assert resp.status == 200
assert await resp.text() == '"test.txt"'
# Verify file was overwritten
with open(tmp_path / "test.txt", "rb") as f:
assert f.read() == new_content
async def test_post_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "test.txt", "w") as f:
f.write("initial content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/test.txt?overwrite=false", data=b"new content")
assert resp.status == 409
# Verify original content unchanged
with open(tmp_path / "test.txt", "r") as f:
assert f.read() == "initial content"
async def test_post_userdata_full_info(aiohttp_client, app, tmp_path):
client = await aiohttp_client(app)
content = b"test content"
resp = await client.post("/userdata/test.txt?full_info=true", data=content)
assert resp.status == 200
result = await resp.json()
assert result["path"] == "test.txt"
assert result["size"] == len(content)
assert "modified" in result
async def test_move_userdata(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt")
assert resp.status == 200
assert await resp.text() == '"dest.txt"'
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_move_userdata_no_overwrite(aiohttp_client, app, tmp_path):
# Create source and destination files
with open(tmp_path / "source.txt", "w") as f:
f.write("source content")
with open(tmp_path / "dest.txt", "w") as f:
f.write("destination content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?overwrite=false")
assert resp.status == 409
# Verify files remain unchanged
with open(tmp_path / "source.txt", "r") as f:
assert f.read() == "source content"
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "destination content"
async def test_move_userdata_full_info(aiohttp_client, app, tmp_path):
# Create initial file
with open(tmp_path / "source.txt", "w") as f:
f.write("test content")
client = await aiohttp_client(app)
resp = await client.post("/userdata/source.txt/move/dest.txt?full_info=true")
assert resp.status == 200
result = await resp.json()
assert result["path"] == "dest.txt"
assert result["size"] == len("test content")
assert "modified" in result
# Verify file was moved
assert not os.path.exists(tmp_path / "source.txt")
with open(tmp_path / "dest.txt", "r") as f:
assert f.read() == "test content"
async def test_listuserdata_v2_empty_root(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata")
assert resp.status == 200
assert await resp.json() == []
async def test_listuserdata_v2_nonexistent_subdirectory(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata?path=does_not_exist")
assert resp.status == 404
async def test_listuserdata_v2_default(aiohttp_client, app, tmp_path):
os.makedirs(tmp_path / "test_dir" / "subdir")
(tmp_path / "test_dir" / "file1.txt").write_text("content")
(tmp_path / "test_dir" / "subdir" / "file2.txt").write_text("content")
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata?path=test_dir")
assert resp.status == 200
data = await resp.json()
file_paths = {item["path"] for item in data if item["type"] == "file"}
assert file_paths == {"test_dir/file1.txt", "test_dir/subdir/file2.txt"}
async def test_listuserdata_v2_normalized_separators(aiohttp_client, app, tmp_path, monkeypatch):
# Force backslash as os separator
monkeypatch.setattr(os, 'sep', '\\')
monkeypatch.setattr(os.path, 'sep', '\\')
os.makedirs(tmp_path / "test_dir" / "subdir")
(tmp_path / "test_dir" / "subdir" / "file1.txt").write_text("x")
client = await aiohttp_client(app)
resp = await client.get("/v2/userdata?path=test_dir")
assert resp.status == 200
data = await resp.json()
for item in data:
assert "/" in item["path"]
assert "\\" not in item["path"]
async def test_listuserdata_v2_url_encoded_path(aiohttp_client, app, tmp_path):
# Create a directory with a space in its name and a file inside
os.makedirs(tmp_path / "my dir")
(tmp_path / "my dir" / "file.txt").write_text("content")
client = await aiohttp_client(app)
# Use URL-encoded space in path parameter
resp = await client.get("/v2/userdata?path=my%20dir&recurse=false")
assert resp.status == 200
data = await resp.json()
assert len(data) == 1
entry = data[0]
assert entry["name"] == "file.txt"
# Ensure the path is correctly decoded and uses forward slash
assert entry["path"] == "my dir/file.txt"

View File

@ -122,7 +122,7 @@ def test_bool_request_parameter():
assert v == True
def test_string_enum_request_parameter():
async def test_string_enum_request_parameter():
nt = StringEnumRequestParameter.INPUT_TYPES()
assert nt is not None
n = StringEnumRequestParameter()
@ -155,8 +155,9 @@ def test_string_enum_request_parameter():
}
from comfy.cmd.execution import validate_inputs
validated: dict[str, ValidateInputsTuple] = {}
validated["1"] = validate_inputs(prompt, "1", validated)
validated["2"] = validate_inputs(prompt, "2", validated)
prompt_id = str(uuid.uuid4())
validated["1"] = await validate_inputs(prompt_id, prompt, "1", validated)
validated["2"] = await validate_inputs(prompt_id, prompt, "2", validated)
assert validated["2"].valid