diff --git a/README.md b/README.md index 347fcb520..4966f879f 100644 --- a/README.md +++ b/README.md @@ -722,19 +722,31 @@ You can pass additional extra model path configurations with one or more copies ### Command Line Arguments ``` -usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [-H [IP]] [--port PORT] [--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE] - [--extra-model-paths-config PATH [PATH ...]] [--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch] - [--disable-auto-launch] [--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16] +usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [--base-paths BASE_PATHS [BASE_PATHS ...]] [-H [IP]] [--port PORT] + [--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE] [--extra-model-paths-config PATH [PATH ...]] + [--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch] [--disable-auto-launch] + [--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16] [--bf16-unet | --fp16-unet | --fp8_e4m3fn-unet | --fp8_e5m2-unet] [--fp16-vae | --fp32-vae | --bf16-vae] [--cpu-vae] [--fp8_e4m3fn-text-enc | --fp8_e5m2-text-enc | --fp16-text-enc | --fp32-text-enc] [--directml [DIRECTML_DEVICE]] [--disable-ipex-optimize] - [--preview-method [none,auto,latent2rgb,taesd]] [--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers] - [--force-upcast-attention | --dont-upcast-attention] [--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--disable-smart-memory] [--deterministic] - [--dont-print-server] [--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--multi-user] [--create-directories] - [--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN] [--analytics-use-identity-provider] - [--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker] [--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME] - [--external-address EXTERNAL_ADDRESS] [--verbose] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE] [--otel-service-name OTEL_SERVICE_NAME] - [--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT] - + [--preview-method [none,auto,latent2rgb,taesd]] [--preview-size PREVIEW_SIZE] [--cache-lru CACHE_LRU] + [--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers] [--disable-flash-attn] + [--disable-sage-attention] [--force-upcast-attention | --dont-upcast-attention] + [--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--reserve-vram RESERVE_VRAM] + [--default-hashing-function {md5,sha1,sha256,sha512}] [--disable-smart-memory] [--deterministic] [--fast] [--dont-print-server] + [--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--disable-all-custom-nodes] [--multi-user] [--create-directories] + [--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN] + [--analytics-use-identity-provider] [--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker] + [--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME] [--external-address EXTERNAL_ADDRESS] + [--logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE] + [--otel-service-name OTEL_SERVICE_NAME] [--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT] + [--force-channels-last] [--force-hf-local-dir-mode] [--front-end-version FRONT_END_VERSION] [--front-end-root FRONT_END_ROOT] + [--executor-factory EXECUTOR_FACTORY] [--openai-api-key OPENAI_API_KEY] [--user-directory USER_DIRECTORY] [--blip-model-url BLIP_MODEL_URL] + [--blip-model-vqa-url BLIP_MODEL_VQA_URL] [--sam-model-vith-url SAM_MODEL_VITH_URL] [--sam-model-vitl-url SAM_MODEL_VITL_URL] + [--sam-model-vitb-url SAM_MODEL_VITB_URL] [--history-display-limit HISTORY_DISPLAY_LIMIT] [--ffmpeg-bin-path FFMPEG_BIN_PATH] + [--ffmpeg-extra-codecs FFMPEG_EXTRA_CODECS] [--wildcards-path WILDCARDS_PATH] [--wildcard-api WILDCARD_API] [--photoprism-host PHOTOPRISM_HOST] + [--immich-host IMMICH_HOST] [--ideogram-session-cookie IDEOGRAM_SESSION_COOKIE] [--annotator-ckpts-path ANNOTATOR_CKPTS_PATH] [--use-symlinks] + [--ort-providers ORT_PROVIDERS] [--vfi-ops-backend VFI_OPS_BACKEND] [--dependency-version DEPENDENCY_VERSION] [--mmdet-skip] [--sam-editor-cpu] + [--sam-editor-model SAM_EDITOR_MODEL] [--custom-wildcards CUSTOM_WILDCARDS] [--disable-gpu-opencv] options: -h, --help show this help message and exit @@ -742,10 +754,14 @@ options: config file path --write-out-config-file CONFIG_OUTPUT_PATH takes the current command line args and writes them out to a config file at the given path, then exits - -w CWD, --cwd CWD Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default. [env var: - COMFYUI_CWD] + -w CWD, --cwd CWD Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be + located here by default. [env var: COMFYUI_CWD] + --base-paths BASE_PATHS [BASE_PATHS ...] + Additional base paths for custom nodes, models and inputs. [env var: COMFYUI_BASE_PATHS] -H [IP], --listen [IP] - Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all) [env var: COMFYUI_LISTEN] + Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: + 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6) [env var: + COMFYUI_LISTEN] --port PORT Set the listen port. [env var: COMFYUI_PORT] --enable-cors-header [ORIGIN] Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'. [env var: COMFYUI_ENABLE_CORS_HEADER] @@ -789,6 +805,10 @@ options: Disables ipex.optimize when loading models with Intel GPUs. [env var: COMFYUI_DISABLE_IPEX_OPTIMIZE] --preview-method [none,auto,latent2rgb,taesd] Default preview method for sampler nodes. [env var: COMFYUI_PREVIEW_METHOD] + --preview-size PREVIEW_SIZE + Sets the maximum preview size for sampler nodes. [env var: COMFYUI_PREVIEW_SIZE] + --cache-lru CACHE_LRU + Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM. [env var: COMFYUI_CACHE_LRU] --use-split-cross-attention Use the split cross attention optimization. Ignored when xformers is used. [env var: COMFYUI_USE_SPLIT_CROSS_ATTENTION] --use-quad-cross-attention @@ -796,6 +816,9 @@ options: --use-pytorch-cross-attention Use the new pytorch 2.0 cross attention function. [env var: COMFYUI_USE_PYTORCH_CROSS_ATTENTION] --disable-xformers Disable xformers. [env var: COMFYUI_DISABLE_XFORMERS] + --disable-flash-attn Disable Flash Attention [env var: COMFYUI_DISABLE_FLASH_ATTN] + --disable-sage-attention + Disable Sage Attention [env var: COMFYUI_DISABLE_SAGE_ATTENTION] --force-upcast-attention Force enable attention upcasting, please report if it fixes black images. [env var: COMFYUI_FORCE_UPCAST_ATTENTION] --dont-upcast-attention @@ -806,15 +829,25 @@ options: --lowvram Split the unet in parts to use less vram. [env var: COMFYUI_LOWVRAM] --novram When lowvram isn't enough. [env var: COMFYUI_NOVRAM] --cpu To use the CPU for everything (slow). [env var: COMFYUI_CPU] + --reserve-vram RESERVE_VRAM + Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS. + [env var: COMFYUI_RESERVE_VRAM] + --default-hashing-function {md5,sha1,sha256,sha512} + Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256. [env var: + COMFYUI_DEFAULT_HASHING_FUNCTION] --disable-smart-memory Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can. [env var: COMFYUI_DISABLE_SMART_MEMORY] - --deterministic Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var: COMFYUI_DETERMINISTIC] + --deterministic Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var: + COMFYUI_DETERMINISTIC] + --fast Enable some untested and potentially quality deteriorating optimizations. [env var: COMFYUI_FAST] --dont-print-server Don't print server output. [env var: COMFYUI_DONT_PRINT_SERVER] --quick-test-for-ci Quick test for CI. Raises an error if nodes cannot be imported, [env var: COMFYUI_QUICK_TEST_FOR_CI] --windows-standalone-build - Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup). [env var: - COMFYUI_WINDOWS_STANDALONE_BUILD] + Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening + the page on startup). [env var: COMFYUI_WINDOWS_STANDALONE_BUILD] --disable-metadata Disable saving prompt metadata in files. [env var: COMFYUI_DISABLE_METADATA] + --disable-all-custom-nodes + Disable loading all custom nodes. [env var: COMFYUI_DISABLE_ALL_CUSTOM_NODES] --multi-user Enables per-user storage. [env var: COMFYUI_MULTI_USER] --create-directories Creates the default models/, input/, output/ and temp/ directories, then exits. [env var: COMFYUI_CREATE_DIRECTORIES] --plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL @@ -824,18 +857,19 @@ options: --analytics-use-identity-provider Uses platform identifiers for unique visitor analytics. [env var: COMFYUI_ANALYTICS_USE_IDENTITY_PROVIDER] --distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI - EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress - updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI] + EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt + execution requests and progress updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI] --distributed-queue-worker Workers will pull requests off the AMQP URL. [env var: COMFYUI_DISTRIBUTED_QUEUE_WORKER] --distributed-queue-frontend Frontends will start the web UI and connect to the provided AMQP URL to submit prompts. [env var: COMFYUI_DISTRIBUTED_QUEUE_FRONTEND] --distributed-queue-name DISTRIBUTED_QUEUE_NAME - This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the - user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME] + This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue + name, followed by a '.', then the user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME] --external-address EXTERNAL_ADDRESS Specifies a base URL for external addresses reported by the API, such as for image paths. [env var: COMFYUI_EXTERNAL_ADDRESS] - --verbose Enables more debug prints. [env var: COMFYUI_VERBOSE] + --logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL} + Set the logging level [env var: COMFYUI_LOGGING_LEVEL] --disable-known-models Disables automatic downloads of known models and prevents them from appearing in the UI. [env var: COMFYUI_DISABLE_KNOWN_MODELS] --max-queue-size MAX_QUEUE_SIZE @@ -845,8 +879,27 @@ options: --otel-service-version OTEL_SERVICE_VERSION The version of the service or application that is generating telemetry data. [env var: OTEL_SERVICE_VERSION] --otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT - A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one - environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT] + A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the + same endpoint and want one environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT] + --force-channels-last + Force channels last format when inferencing the models. [env var: COMFYUI_FORCE_CHANNELS_LAST] + --force-hf-local-dir-mode + Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with + the "cache_dir" argument, recreating the traditional file structure. [env var: COMFYUI_FORCE_HF_LOCAL_DIR_MODE] + --front-end-version FRONT_END_VERSION + Specifies the version of the frontend to be used. This command needs internet connectivity to query and download available frontend + implementations from GitHub releases. The version string should be in the format of: [repoOwner]/[repoName]@[version] where version is one of: + "latest" or a valid version number (e.g. "1.0.0") [env var: COMFYUI_FRONT_END_VERSION] + --front-end-root FRONT_END_ROOT + The local filesystem path to the directory where the frontend is located. Overrides --front-end-version. [env var: COMFYUI_FRONT_END_ROOT] + --executor-factory EXECUTOR_FACTORY + When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow + worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and + large, contiguous blocks of CUDA memory can be freed. [env var: COMFYUI_EXECUTOR_FACTORY] + --openai-api-key OPENAI_API_KEY + Configures the OpenAI API Key for the OpenAI nodes [env var: OPENAI_API_KEY] + --user-directory USER_DIRECTORY + Set the ComfyUI user directory with an absolute path. [env var: COMFYUI_USER_DIRECTORY] Args that start with '--' can also be set in a config file (config.yaml or config.json or specified via -c). Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see syntax at https://goo.gl/R74nmi). In general, command-line values override environment variables which override config file values which override defaults. diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 49e293c83..5933327d7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -28,6 +28,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument('-w', "--cwd", type=str, default=None, help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.") + parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.") parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::", help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)") parser.add_argument("--port", type=int, default=8188, help="Set the listen port.") @@ -161,7 +162,7 @@ def _create_parser() -> EnhancedConfigArgParser: help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID") parser.add_argument("--external-address", required=False, help="Specifies a base URL for external addresses reported by the API, such as for image paths.") - parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') + parser.add_argument("--logging-level", type=str, default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level') parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.") parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.") # tracing @@ -251,11 +252,6 @@ def _parse_args(parser: Optional[argparse.ArgumentParser] = None, args_parsing: if args.disable_auto_launch: args.auto_launch = False - logging_level = logging.INFO - if args.verbose: - logging_level = logging.DEBUG - - logging.basicConfig(format="%(message)s", level=logging_level) configuration_obj = Configuration(**vars(args)) configuration_obj.config_files = config_files assert all(isinstance(config_file, str) for config_file in config_files) diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index b8e83b2a3..95409ee45 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -1,6 +1,8 @@ from __future__ import annotations +import collections import enum +from pathlib import Path from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple import configargparse @@ -36,15 +38,16 @@ class Configuration(dict): Attributes: config_files (Optional[List[str]]): Path to the configuration file(s) that were set in the arguments. - cwd (Optional[str]): Working directory. Defaults to the current directory. + cwd (Optional[str]): Working directory. Defaults to the current directory. This is always treated as a base path for model files, and it will be the place where model files are downloaded. + base_paths (Optional[list[str]]): Additional base paths for custom nodes, models and inputs. listen (str): IP address to listen on. Defaults to "127.0.0.1". port (int): Port number for the server to listen on. Defaults to 8188. enable_cors_header (Optional[str]): Enables CORS with the specified origin. max_upload_size (float): Maximum upload size in MB. Defaults to 100. extra_model_paths_config (Optional[List[str]]): Extra model paths configuration files. - output_directory (Optional[str]): Directory for output files. + output_directory (Optional[str]): Directory for output files. This can also be a relative path to the cwd or current working directory. temp_directory (Optional[str]): Temporary directory for processing. - input_directory (Optional[str]): Directory for input files. + input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths. auto_launch (bool): Auto-launch UI in the default browser. Defaults to False. disable_auto_launch (bool): Disable auto-launching the browser. cuda_device (Optional[int]): CUDA device ID. None means default device. @@ -87,7 +90,6 @@ class Configuration(dict): reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS disable_smart_memory (bool): Disable smart memory management. deterministic (bool): Use deterministic algorithms where possible. - dont_print_server (bool): Suppress server output. quick_test_for_ci (bool): Enable quick testing mode for CI. windows_standalone_build (bool): Enable features for standalone Windows build. disable_metadata (bool): Disable saving metadata with outputs. @@ -103,7 +105,7 @@ class Configuration(dict): distributed_queue_worker (bool): Workers will pull requests off the AMQP URL. distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID. external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths. - verbose (bool | str): Shows extra output for debugging purposes such as import errors of custom nodes; or, specifies a log level + logging_level (str): Specifies a log level disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI. max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value. otel_service_name (str): The name of the service or application that is generating telemetry data. Default: "comfyui". @@ -122,6 +124,7 @@ class Configuration(dict): self._observers: List[ConfigObserver] = [] self.config_files = [] self.cwd: Optional[str] = None + self.base_paths: list[Path] = [] self.listen: str = "127.0.0.1" self.port: int = 8188 self.enable_cors_header: Optional[str] = None @@ -192,7 +195,7 @@ class Configuration(dict): self.force_channels_last: bool = False self.force_hf_local_dir_mode = False self.preview_size: int = 512 - self.verbose: str | bool = "INFO" + self.logging_level: str = "INFO" # from guill self.cache_lru: int = 0 @@ -253,6 +256,17 @@ class Configuration(dict): self.update(state) self._observers = [] + @property + def verbose(self) -> str: + return self.logging_level + + @verbose.setter + def verbose(self, value): + if isinstance(value, bool): + self.logging_level = "DEBUG" + else: + self.logging_level = value + class EnumAction(argparse.Action): """ diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 1789709d9..df6473663 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -1,13 +1,14 @@ from __future__ import annotations import asyncio -import contextvars import gc import json +import os +import threading import uuid from asyncio import get_event_loop -from concurrent.futures import ThreadPoolExecutor from multiprocessing import RLock +from pathlib import Path from typing import Optional from opentelemetry import context, propagate @@ -17,13 +18,16 @@ from opentelemetry.trace import Status, StatusCode from .client_types import V1QueuePromptResponse from ..api.components.schema.prompt import PromptDict from ..cli_args_types import Configuration +from ..cmd.folder_paths import init_default_paths from ..cmd.main_pre import tracer -from ..component_model.executor_types import ExecutorToClientProgress, Executor +from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.make_mutable import make_mutable +from ..distributed.executors import ContextVarExecutor from ..distributed.process_pool_executor import ProcessPoolExecutor from ..distributed.server_stub import ServerStub +from ..execution_context import current_execution_context -_prompt_executor = contextvars.ContextVar('prompt_executor') +_prompt_executor = threading.local() def _execute_prompt( @@ -33,6 +37,9 @@ def _execute_prompt( span_context: dict, progress_handler: ExecutorToClientProgress | None, configuration: Configuration | None) -> dict: + execution_context = current_execution_context() + if len(execution_context.folder_names_and_paths) == 0 or configuration is not None: + init_default_paths(execution_context.folder_names_and_paths, configuration) span_context: Context = propagate.extract(span_context) token = attach(span_context) try: @@ -52,8 +59,8 @@ def __execute_prompt( progress_handler = progress_handler or ServerStub() try: - prompt_executor = _prompt_executor.get() - except LookupError: + prompt_executor: PromptExecutor = _prompt_executor.executor + except (LookupError, AttributeError): if configuration is None: options.enable_args_parsing() else: @@ -65,7 +72,7 @@ def __execute_prompt( with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span: prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0) prompt_executor.raise_exceptions = True - _prompt_executor.set(prompt_executor) + _prompt_executor.executor = prompt_executor with tracer.start_as_current_span("Execute Prompt", context=span_context) as span: try: @@ -96,6 +103,13 @@ def __execute_prompt( def _cleanup(): + from ..cmd.execution import PromptExecutor + try: + prompt_executor: PromptExecutor = _prompt_executor.executor + # this should clear all references to output tensors and make it easier to collect back the memory + prompt_executor.reset() + except (LookupError, AttributeError): + pass from .. import model_management model_management.unload_all_models() gc.collect() @@ -139,9 +153,9 @@ class EmbeddedComfyClient: In order to use this in blocking methods, learn more about asyncio online. """ - def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None): + def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor = None): self._progress_handler = progress_handler or ServerStub() - self._executor = executor or ThreadPoolExecutor(max_workers=max_workers) + self._executor = executor or ContextVarExecutor(max_workers=max_workers) self._configuration = configuration self._is_running = False self._task_count_lock = RLock() diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 27c852b39..087d8f662 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -27,7 +27,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage from ..component_model.files import canonicalize_path from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus -from ..execution_context import new_execution_context, context_execute_node, ExecutionContext +from ..execution_context import context_execute_node, context_execute_prompt from ..nodes.package import import_all_nodes_in_workspace from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode @@ -77,24 +77,19 @@ class IsChangedCache: class CacheSet: def __init__(self, lru_size=None): if lru_size is None or lru_size == 0: - self.init_classic_cache() + # Performs like the old cache -- dump data ASAP + + self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.ui = HierarchicalCache(CacheKeySetInputSignature) + self.objects = HierarchicalCache(CacheKeySetID) else: - self.init_lru_cache(lru_size) + # Useful for those with ample RAM/VRAM -- allows experimenting without + # blowing away the cache every time + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=lru_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=lru_size) + self.objects = HierarchicalCache(CacheKeySetID) self.all = [self.outputs, self.ui, self.objects] - # Useful for those with ample RAM/VRAM -- allows experimenting without - # blowing away the cache every time - def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.objects = HierarchicalCache(CacheKeySetID) - - # Performs like the old cache -- dump data ASAP - def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignature) - self.objects = HierarchicalCache(CacheKeySetID) - def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), @@ -308,11 +303,11 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches, :param pending_subgraph_results: :return: """ - with context_execute_node(_node_id, prompt_id): + with context_execute_node(_node_id): return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple: +def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple: unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -548,7 +543,7 @@ class PromptExecutor: # torchao and potentially other optimization approaches break when the models are created in inference mode # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) - with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)): + with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode): self._execute_inner(prompt, prompt_id, extra_data, execute_outputs) def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True): diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index ee09a19a9..b0b4ce9c3 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -4,142 +4,144 @@ import logging import mimetypes import os import time -from typing import Optional, List, Final, Literal +from contextlib import nullcontext +from pathlib import Path +from typing import Optional, List, Literal -from .folder_paths_pre import get_base_path +from ..cli_args_types import Configuration +from ..component_model.deprecation import _deprecate_method from ..component_model.files import get_package_as_path -from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse -from ..component_model.folder_path_types import extension_mimetypes_cache as _extension_mimetypes_cache -from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions -from ..component_model.module_property import module_property +from ..component_model.folder_path_types import FolderNames, SaveImagePathTuple, ModelPaths +from ..component_model.folder_path_types import supported_pt_extensions, extension_mimetypes_cache +from ..component_model.module_property import create_module_properties +from ..component_model.platform_path import construct_path +from ..execution_context import current_execution_context -supported_pt_extensions: Final[frozenset[str]] = _supported_pt_extensions -extension_mimetypes_cache: Final[dict[str, str]] = _extension_mimetypes_cache +_module_properties = create_module_properties() + + +@_module_properties.getter +def _supported_pt_extensions() -> frozenset[str]: + return supported_pt_extensions + + +@_module_properties.getter +def _extension_mimetypes_cache() -> dict[str, str]: + return extension_mimetypes_cache # todo: this needs to be wrapped in a context and configurable -@module_property +@_module_properties.getter def _base_path(): - return get_base_path() + return _folder_names_and_paths().base_paths[0] -models_dir = os.path.join(get_base_path(), "models") -folder_names_and_paths: Final[FolderNames] = FolderNames(models_dir) -folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions)) -folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"}) -folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions)) -folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions)) -folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions)) -folder_names_and_paths["unet"] = folder_names_and_paths["diffusion_models"] = FolderPathsTuple("diffusion_models", [os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], set(supported_pt_extensions)) -folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions)) -folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions)) -folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions)) -folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"}) -folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions)) -folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions)) -folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions)) -folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions)) -folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(get_base_path(), "custom_nodes")], set()) -folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions)) -folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions)) -folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""}) -folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""}) -folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""}) - -output_directory = os.path.join(get_base_path(), "output") -temp_directory = os.path.join(get_base_path(), "temp") -input_directory = os.path.join(get_base_path(), "input") -user_directory = os.path.join(get_base_path(), "user") - -filename_list_cache = {} +def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None): + from ..cmd.main_pre import args + configuration = configuration or args + base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths + base_paths = [path for path in base_paths if path is not None] + if len(base_paths) == 0: + base_paths = [Path(os.getcwd())] + for base_path in base_paths: + folder_names_and_paths.add_base_path(base_path) + folder_names_and_paths.add(ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"})) + folder_names_and_paths.add(ModelPaths(["vae"], supported_extensions={".yaml"})) + folder_names_and_paths.add(ModelPaths(["clip"], supported_extensions={".yaml"})) + folder_names_and_paths.add(ModelPaths(["loras"], supported_extensions={".yaml"})) + folder_names_and_paths.add(ModelPaths(["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["diffusers"], supported_extensions=set())) + folder_names_and_paths.add(ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set())) + folder_names_and_paths.add(ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions))) + folder_names_and_paths.add(ModelPaths(["classifiers"], supported_extensions=set())) + folder_names_and_paths.add(ModelPaths(["huggingface"], supported_extensions=set())) + hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set()) + # TODO: explore if there is a better way to do this + if "HF_HUB_CACHE" in os.environ: + hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE")) + folder_names_and_paths.add(hf_cache_paths) + create_directories(folder_names_and_paths) -class CacheHelper: - """ - Helper class for managing file list cache data. - """ - - def __init__(self): - self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {} - self.active = False - - def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]: - if not self.active: - return default - return self.cache.get(key, default) - - def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None: - if self.active: - self.cache[key] = value - - def clear(self): - self.cache.clear() - - def __enter__(self): - self.active = True - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.active = False - self.clear() +@_module_properties.getter +def _folder_names_and_paths(): + return current_execution_context().folder_names_and_paths -cache_helper = CacheHelper() +@_module_properties.getter +def _models_dir(): + return str(Path(current_execution_context().folder_names_and_paths.base_paths[0]) / construct_path("models")) +@_module_properties.getter +def _user_directory() -> str: + return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve()) + + +@_module_properties.getter +def _temp_directory() -> str: + return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve()) + + +@_module_properties.getter +def _input_directory() -> str: + return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve()) + + +@_module_properties.getter +def _output_directory() -> str: + return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve()) + + +@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.") def map_legacy(folder_name: str) -> str: legacy = {"unet": "diffusion_models"} return legacy.get(folder_name, folder_name) -if not os.path.exists(input_directory): - try: - os.makedirs(input_directory) - except: - logging.error("Failed to create input directory") +def set_output_directory(output_dir: str | Path): + _folder_names_and_paths().application_paths.output_directory = construct_path(output_dir) -def set_output_directory(output_dir): - global output_directory - output_directory = output_dir +def set_temp_directory(temp_dir: str | Path): + _folder_names_and_paths().application_paths.temp_directory = construct_path(temp_dir) -def set_temp_directory(temp_dir): - global temp_directory - temp_directory = temp_dir +def set_input_directory(input_dir: str | Path): + _folder_names_and_paths().application_paths.input_directory = construct_path(input_dir) -def set_input_directory(input_dir): - global input_directory - input_directory = input_dir +def get_output_directory() -> str: + return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve()) -def get_output_directory(): - global output_directory - return output_directory +def get_temp_directory() -> str: + return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve()) -def get_temp_directory(): - global temp_directory - return temp_directory - - -def get_input_directory(): - global input_directory - return input_directory +def get_input_directory() -> str: + return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve()) def get_user_directory() -> str: - return user_directory + return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve()) -def set_user_directory(user_dir: str) -> None: - global user_directory - user_directory = user_dir +def set_user_directory(user_dir: str | Path) -> None: + _folder_names_and_paths().application_paths.user_directory = construct_path(user_dir) # NOTE: used in http server so don't put folders that should not be accessed remotely -def get_directory_by_type(type_name): +def get_directory_by_type(type_name) -> str | None: if type_name == "output": return get_output_directory() if type_name == "temp": @@ -151,7 +153,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def annotated_filepath(name): +def annotated_filepath(name: str) -> tuple[str, str | None]: if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -167,7 +169,7 @@ def annotated_filepath(name): return name, base_dir -def get_annotated_filepath(name, default_dir=None): +def get_annotated_filepath(name, default_dir=None) -> str: name, base_dir = annotated_filepath(name) if base_dir is None: @@ -198,9 +200,11 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e :param extensions: supported file extensions :return: the folder path """ - global folder_names_and_paths + folder_names_and_paths = _folder_names_and_paths() if full_folder_path is None: - full_folder_path = os.path.join(models_dir, folder_name) + # todo: this should use the subdir patter + + full_folder_path = os.path.join(_models_dir(), folder_name) folder_path = folder_names_and_paths[folder_name] if full_folder_path not in folder_path.paths: @@ -212,48 +216,16 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e if extensions is not None: folder_path.supported_extensions |= extensions - invalidate_cache(folder_name) return full_folder_path def get_folder_paths(folder_name) -> List[str]: - return folder_names_and_paths[folder_name].paths[:] + return [path for path in _folder_names_and_paths()[folder_name].paths] -def recursive_search(directory, excluded_dir_names=None): - if not os.path.isdir(directory): - return [], {} - - if excluded_dir_names is None: - excluded_dir_names = [] - - result = [] - dirs = {} - - # Attempt to add the initial directory to dirs with error handling - try: - dirs[directory] = os.path.getmtime(directory) - except FileNotFoundError: - logging.warning(f"Warning: Unable to access {directory}. Skipping this path.") - - for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): - subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] - for file_name in filenames: - try: - relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) - result.append(relative_path) - except: - logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.") - continue - - for d in subdirs: - path = os.path.join(dirpath, d) - try: - dirs[path] = os.path.getmtime(path) - except FileNotFoundError: - logging.warning(f"Warning: Unable to access {path}. Skipping this path.") - continue - return result, dirs +@_deprecate_method(version="0.2.3", message="Not supported") +def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]: + raise NotImplemented("Unsupported method") def filter_files_extensions(files, extensions): @@ -264,92 +236,27 @@ def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]: """ Gets the path to a filename inside a folder. - Works with untrusted filenames. :param folder_name: :param filename: :return: """ - global folder_names_and_paths - folders = folder_names_and_paths[folder_name].paths - filename_split = os.path.split(filename) - - trusted_paths = [] - for folder in folders: - folder_split = os.path.split(folder) - abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split)) - abs_folder_path = os.path.abspath(folder) - if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path: - trusted_paths.append(abs_file_path) - else: - logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}") - - for trusted_path in trusted_paths: - if os.path.isfile(trusted_path): - return trusted_path - - return None + path = _folder_names_and_paths().first_existing_or_none(folder_name, construct_path(filename)) + return str(path) if path is not None else None def get_full_path_or_raise(folder_name: str, filename: str) -> str: full_path = get_full_path(folder_name, filename) if full_path is None: + # todo: probably shouldn't say model raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") return full_path -def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]: - folder_name = map_legacy(folder_name) - global folder_names_and_paths - output_list = set() - folders = folder_names_and_paths[folder_name] - output_folders = {} - for x in folders[0]: - files, folders_all = recursive_search(x, excluded_dir_names=[".git"]) - output_list.update(filter_files_extensions(files, folders[1])) - output_folders = {**output_folders, **folders_all} - - return sorted(list(output_list)), output_folders, time.perf_counter() - - -def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None: - strong_cache = cache_helper.get(folder_name) - if strong_cache is not None: - return strong_cache - - global filename_list_cache - global folder_names_and_paths - folder_name = map_legacy(folder_name) - if folder_name not in filename_list_cache: - return None - out = filename_list_cache[folder_name] - - for x in out[1]: - time_modified = out[1][x] - folder = x - if os.path.getmtime(folder) != time_modified: - return None - - folders = folder_names_and_paths[folder_name] - for x in folders[0]: - if os.path.isdir(x): - if x not in out[1]: - return None - - return out - - def get_filename_list(folder_name: str) -> list[str]: - folder_name = map_legacy(folder_name) - out = cached_filename_list_(folder_name) - if out is None: - out = get_filename_list_(folder_name) - global filename_list_cache - filename_list_cache[folder_name] = out - cache_helper.set(folder_name, out) - return list(out[0]) + return [str(path) for path in _folder_names_and_paths().file_paths(folder_name=folder_name, relative=True)] -def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): +def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0) -> SaveImagePathTuple: def map_filename(filename: str) -> tuple[int, str]: prefix_len = len(os.path.basename(filename_prefix)) prefix = filename[:prefix_len + 1] @@ -378,14 +285,6 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height full_output_folder = str(os.path.join(output_dir, subfolder)) - if str(os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))) != str(output_dir): - err = f"""**** ERROR: Saving image outside the output folder is not allowed. - full_output_folder: {os.path.abspath(full_output_folder)} - output_dir: {output_dir} - commonpath: {os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))}""" - logging.error(err) - raise Exception(err) - try: counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 except ValueError: @@ -393,21 +292,22 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height except FileNotFoundError: os.makedirs(full_output_folder, exist_ok=True) counter = 1 - return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix) + return SaveImagePathTuple(full_output_folder, filename, counter, subfolder, filename_prefix) -def create_directories(): +def create_directories(paths: FolderNames | None): # all configured paths should be created - for folder_path_spec in folder_names_and_paths.values(): + paths = paths or _folder_names_and_paths() + for folder_path_spec in paths.values(): for path in folder_path_spec.paths: os.makedirs(path, exist_ok=True) - for path in (temp_directory, input_directory, output_directory, user_directory): - os.makedirs(path, exist_ok=True) + for path in paths.application_paths: + path.mkdir(exist_ok=True) +@_deprecate_method(version="0.2.3", message="Caching has been removed.") def invalidate_cache(folder_name): - global filename_list_cache - filename_list_cache.pop(folder_name, None) + pass def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]: @@ -416,7 +316,7 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im files = os.listdir(folder_paths.get_input_directory()) filter_files_content_types(files, ["image", "audio", "video"]) """ - global extension_mimetypes_cache + extension_mimetypes_cache = _extension_mimetypes_cache() result = [] for file in files: extension = file.split('.')[-1] @@ -432,3 +332,52 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im if content_type in content_types: result.append(file) return result + + +@_module_properties.getter +def _cache_helper(): + return nullcontext() + + +# todo: can this be done side effect free? +init_default_paths(_folder_names_and_paths()) + +__all__ = [ + # Properties (stripped leading underscore) + "supported_pt_extensions", # from _supported_pt_extensions + "extension_mimetypes_cache", # from _extension_mimetypes_cache + "base_path", # from _base_path + "folder_names_and_paths", # from _folder_names_and_paths + "models_dir", # from _models_dir + "user_directory", + "output_directory", + "temp_directory", + "input_directory", + + # Public functions + "init_default_paths", + "map_legacy", + "set_output_directory", + "set_temp_directory", + "set_input_directory", + "get_output_directory", + "get_temp_directory", + "get_input_directory", + "get_user_directory", + "set_user_directory", + "get_directory_by_type", + "annotated_filepath", + "get_annotated_filepath", + "exists_annotated_filepath", + "add_model_folder_path", + "get_folder_paths", + "recursive_search", + "filter_files_extensions", + "get_full_path", + "get_full_path_or_raise", + "get_filename_list", + "get_save_image_path", + "create_directories", + "invalidate_cache", + "filter_files_content_types" +] diff --git a/comfy/cmd/folder_paths_pre.py b/comfy/cmd/folder_paths_pre.py deleted file mode 100644 index dbf0d5885..000000000 --- a/comfy/cmd/folder_paths_pre.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging -import os - -from ..cli_args import args - -_base_path = None - - -# todo: this should be initialized elsewhere in a context -def get_base_path() -> str: - global _base_path - if _base_path is None: - if args.cwd is not None: - if not os.path.exists(args.cwd): - try: - os.makedirs(args.cwd, exist_ok=True) - except: - logging.error("Failed to create custom working directory") - # wrap the path to prevent slashedness from glitching out common path checks - _base_path = os.path.realpath(args.cwd) - else: - _base_path = os.getcwd() - return _base_path - - -def set_base_path(value: str): - global _base_path - _base_path = value - - -__all__ = ["get_base_path", "set_base_path"] diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index dd98e8ea9..2e3fccf75 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import gc import itertools import logging @@ -193,7 +194,9 @@ async def main(from_script_dir: Optional[Path] = None): if not distributed or args.distributed_queue_worker: if distributed: logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.") - threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start() + # todo: this should really be using an executor instead of doing things this jankilicious way + ctx = contextvars.copy_context() + threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start() # server has been imported and things should be looking good initialize_event_tracking(loop) diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 49457b09e..3c7b2e552 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -111,13 +111,7 @@ def _create_tracer(): def _configure_logging(): - if isinstance(args.verbose, str): - logging_level = args.verbose - elif args.verbose == True: - logging_level = logging.DEBUG - else: - logging_level = logging.ERROR - + logging_level = args.logging_level logging.basicConfig(format="%(message)s", level=logging_level) diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 4de6279e9..022b80baf 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -28,8 +28,8 @@ from aiohttp import web from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module from typing_extensions import NamedTuple -from .. import __version__ from .latent_preview_image_encoding import encode_preview_image +from .. import __version__ from .. import interruption, model_management from .. import node_helpers from .. import utils @@ -241,7 +241,7 @@ class PromptServer(ExecutorToClientProgress): return response @routes.get("/embeddings") - def get_embeddings(self): + def get_embeddings(request): embeddings = folder_paths.get_filename_list("embeddings") return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) @@ -568,15 +568,14 @@ class PromptServer(ExecutorToClientProgress): @routes.get("/object_info") async def get_object_info(request): - with folder_paths.cache_helper: - out = {} - for x in self.nodes.NODE_CLASS_MAPPINGS: - try: - out[x] = node_info(x) - except Exception as e: - logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") - logging.error(traceback.format_exc()) - return web.json_response(out) + out = {} + for x in self.nodes.NODE_CLASS_MAPPINGS: + try: + out[x] = node_info(x) + except Exception as e: + logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") + logging.error(traceback.format_exc()) + return web.json_response(out) @routes.get("/object_info/{node_class}") async def get_object_info_node(request): @@ -721,21 +720,21 @@ class PromptServer(ExecutorToClientProgress): async def get_api_v1_prompts_prompt_id(request: web.Request) -> web.Response | web.FileResponse: prompt_id: str = request.match_info.get("prompt_id", "") if prompt_id == "": - return web.Response(status=404) + return web.json_response(status=404) history_items = self.prompt_queue.get_history(prompt_id) if len(history_items) == 0 or prompt_id not in history_items: # todo: this should really be moved to a stateful queue abstraction if prompt_id in self.background_tasks: - return web.Response(status=204) + return web.json_response(status=204) else: # todo: this should check a stateful queue abstraction - return web.Response(status=404) + return web.json_response(status=404) elif prompt_id in history_items: history_entry = history_items[prompt_id] return web.json_response(history_entry["outputs"]) else: - return web.Response(status=500) + return web.json_response(status=500) @routes.post("/api/v1/prompts") async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: @@ -751,8 +750,8 @@ class PromptServer(ExecutorToClientProgress): queue_size = self.prompt_queue.size() queue_too_busy_size = PromptServer.get_too_busy_queue_size() if queue_size > queue_too_busy_size: - return web.Response(status=429, - reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker") + return web.json_response(status=429, + reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker") # read the request prompt_dict: dict = {} if content_type == 'application/json': diff --git a/comfy/cmd/worker.py b/comfy/cmd/worker.py index 37a795af6..1cecdda88 100644 --- a/comfy/cmd/worker.py +++ b/comfy/cmd/worker.py @@ -2,10 +2,11 @@ import asyncio import itertools import logging import os -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor -from .extra_model_paths import load_extra_path_config from .main_pre import args +from .extra_model_paths import load_extra_path_config +from ..distributed.executors import ContextVarExecutor async def main(): @@ -43,7 +44,7 @@ async def main(): from ..distributed.distributed_prompt_worker import DistributedPromptWorker async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, queue_name=args.distributed_queue_name, - executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)): + executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)): stop = asyncio.Event() try: await stop.wait() diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 32bc1d22d..dd73cabf0 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -8,6 +8,7 @@ from typing import Optional, Literal, Protocol, Union, NamedTuple, List import PIL.Image from typing_extensions import NotRequired, TypedDict +from .outputs_types import OutputsDict from .queue_types import BinaryEventTypes from ..cli_args_types import Configuration from ..nodes.package_typing import InputTypeSpec @@ -205,8 +206,8 @@ class DuplicateNodeError(Exception): class HistoryResultDict(TypedDict, total=True): - outputs: dict - meta: dict + outputs: OutputsDict + meta: OutputsDict class DependencyCycleError(Exception): diff --git a/comfy/component_model/file_output_path.py b/comfy/component_model/file_output_path.py index ed2182706..dcb7e8bc4 100644 --- a/comfy/component_model/file_output_path.py +++ b/comfy/component_model/file_output_path.py @@ -1,15 +1,9 @@ -import os -from typing import Literal, Optional from pathlib import Path +from typing import Literal, Optional from ..cmd import folder_paths -def _is_strictly_below_root(path: Path) -> bool: - resolved_path = path.resolve() - return ".." not in resolved_path.parts and resolved_path.is_absolute() - - def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "output", subfolder: Optional[str] = None) -> str: """ @@ -22,22 +16,15 @@ def file_output_path(filename: str, type: Literal["input", "output", "temp"] = " :return: """ filename, output_dir = folder_paths.annotated_filepath(str(filename)) - if not _is_strictly_below_root(Path(filename)): - raise PermissionError("insecure") if output_dir is None: output_dir = folder_paths.get_directory_by_type(type) if output_dir is None: raise ValueError(f"no such output directory because invalid type specified (type={type})") - if subfolder is not None and subfolder != "": - full_output_dir = str(os.path.join(output_dir, subfolder)) - if str(os.path.commonpath([os.path.abspath(full_output_dir), output_dir])) != str(output_dir): - raise PermissionError("insecure") - output_dir = full_output_dir - filename = os.path.basename(filename) - else: - if str(os.path.commonpath([os.path.abspath(output_dir), os.path.join(output_dir, filename)])) != str(output_dir): - raise PermissionError("insecure") - - file = os.path.join(output_dir, filename) - return file + output_dir = Path(output_dir) + subfolder = Path(subfolder or "") + try: + relative_to = (output_dir / subfolder / filename).relative_to(output_dir) + except ValueError: + raise PermissionError(f"{output_dir / subfolder / filename} is not a subpath of {output_dir}") + return str((output_dir / relative_to).resolve(strict=True)) diff --git a/comfy/component_model/folder_path_types.py b/comfy/component_model/folder_path_types.py index 1feaf9b44..e259f4644 100644 --- a/comfy/component_model/folder_path_types.py +++ b/comfy/component_model/folder_path_types.py @@ -1,9 +1,15 @@ from __future__ import annotations import dataclasses +import itertools import os import typing -from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple +import weakref +from abc import ABC, abstractmethod +from typing import Any, NamedTuple, Optional, Iterable + +from pathlib import Path +from .platform_path import construct_path supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"]) extension_mimetypes_cache = { @@ -11,35 +17,267 @@ extension_mimetypes_cache = { } -@dataclasses.dataclass +def do_add(collection: list | set, index: int | None, item: Any): + if isinstance(collection, list) and index == 0: + collection.insert(0, item) + elif isinstance(collection, set): + collection.add(item) + else: + assert isinstance(collection, list) + collection.append(item) + + class FolderPathsTuple: - folder_name: str - paths: List[str] = dataclasses.field(default_factory=list) - supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions)) + def __init__(self, folder_name: str = None, paths: list[str] = None, supported_extensions: set[str] = None, parent: Optional[weakref.ReferenceType[FolderNames]] = None): + paths = paths or [] + supported_extensions = supported_extensions or set(supported_pt_extensions) + + self.folder_name = folder_name + self.parent = parent + self._paths = paths + self._supported_extensions = supported_extensions + + @property + def supported_extensions(self) -> set[str] | SupportedExtensions: + if self.parent is not None and self.folder_name is not None: + return SupportedExtensions(self.folder_name, self.parent) + else: + return self._supported_extensions + + @supported_extensions.setter + def supported_extensions(self, value: set[str] | SupportedExtensions): + self.supported_extensions.clear() + self.supported_extensions.update(value) + + @property + def paths(self) -> list[str] | PathsList: + if self.parent is not None and self.folder_name is not None: + return PathsList(self.folder_name, self.parent) + else: + return self._paths + + def __iter__(self) -> typing.Generator[typing.Iterable[str]]: + """ + allows this proxy to behave like a tuple everywhere it is used by the custom nodes + :return: + """ + yield self.paths + yield self.supported_extensions def __getitem__(self, item: Any): if item == 0: return self.paths if item == 1: return self.supported_extensions - else: - raise RuntimeError("unsupported tuple index") - def __add__(self, other: "FolderPathsTuple"): - assert self.folder_name == other.folder_name - # todo: make sure the paths are actually unique, as this method intends - new_paths = list(frozenset(self.paths + other.paths)) - new_supported_extensions = self.supported_extensions | other.supported_extensions - return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions) + raise RuntimeError("unsupported tuple index") - def __iter__(self) -> Iterator[Sequence[str]]: - yield self.paths - yield self.supported_extensions + def __iadd__(self, other: FolderPathsTuple): + for path in other.paths: + self.paths.append(path) + + for ext in other.supported_extensions: + self.supported_extensions.add(ext) +@dataclasses.dataclass +class PathsList: + folder_name: str + parent: weakref.ReferenceType[FolderNames] + + def __iter__(self) -> typing.Generator[str]: + p: FolderNames = self.parent() + for path in p.directory_paths(self.folder_name): + try: + yield str(path.resolve()) + except (OSError, AttributeError): + yield str(path) + + def __getitem__(self, item: int): + paths = [x for x in self] + return paths[item] + + def append(self, path_str: str): + p: FolderNames = self.parent() + p.add_paths(self.folder_name, [path_str]) + + def insert(self, path_str: str, index: int): + p: FolderNames = self.parent() + p.add_paths(self.folder_name, [path_str], index=index) + + +@dataclasses.dataclass +class SupportedExtensions: + folder_name: str + parent: weakref.ReferenceType[FolderNames] + + def _append_any(self, other): + if other is None: + return + + p: FolderNames = self.parent() + if isinstance(other, str): + other = {other} + p.add_supported_extension(self.folder_name, *other) + + def __iter__(self) -> typing.Generator[str]: + p: FolderNames = self.parent() + for ext in p.supported_extensions(self.folder_name): + yield ext + + def clear(self): + p: FolderNames = self.parent() + p.remove_all_supported_extensions(self.folder_name) + + __ior__ = _append_any + add = _append_any + update = _append_any + + +@dataclasses.dataclass +class ApplicationPaths: + output_directory: Path = dataclasses.field(default_factory=lambda: construct_path("output")) + temp_directory: Path = dataclasses.field(default_factory=lambda: construct_path("temp")) + input_directory: Path = dataclasses.field(default_factory=lambda: construct_path("input")) + user_directory: Path = dataclasses.field(default_factory=lambda: construct_path("user")) + + def __iter__(self) -> typing.Generator[Path]: + yield self.output_directory + yield self.temp_directory + yield self.input_directory + yield self.user_directory + + +@dataclasses.dataclass +class AbstractPaths(ABC): + folder_names: list[str] + supported_extensions: set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions)) + + @abstractmethod + def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]: + """Generate directory paths based on the given base paths.""" + pass + + @abstractmethod + def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]: + """Generate file paths based on the given base paths.""" + pass + + @abstractmethod + def has_folder_name(self, folder_name: str) -> bool: + """Check if the given folder name is in folder_names.""" + pass + + +@dataclasses.dataclass +class ModelPaths(AbstractPaths): + folder_name_base_path_subdir: Path = dataclasses.field(default_factory=lambda: construct_path("models")) + additional_relative_directory_paths: set[Path] = dataclasses.field(default_factory=set) + additional_absolute_directory_paths: set[str | Path] = dataclasses.field(default_factory=set) + folder_names_are_relative_directory_paths_too: bool = dataclasses.field(default_factory=lambda: True) + + def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]: + yielded_so_far: set[Path] = set() + for base_path in base_paths: + if self.folder_names_are_relative_directory_paths_too: + for folder_name in self.folder_names: + resolved_default_folder_name_path = base_path / self.folder_name_base_path_subdir / folder_name + if not resolved_default_folder_name_path in yielded_so_far: + yield resolved_default_folder_name_path + yielded_so_far.add(resolved_default_folder_name_path) + + for additional_relative_directory_path in self.additional_relative_directory_paths: + resolved_additional_relative_path = base_path / additional_relative_directory_path + if not resolved_additional_relative_path in yielded_so_far: + yield resolved_additional_relative_path + yielded_so_far.add(resolved_additional_relative_path) + + # resolve all paths + yielded_so_far = {path.resolve() for path in yielded_so_far} + for additional_absolute_path in self.additional_absolute_directory_paths: + try: + resolved_absolute_path = additional_absolute_path.resolve() + except (OSError, AttributeError): + resolved_absolute_path = additional_absolute_path + if not resolved_absolute_path in yielded_so_far: + yield resolved_absolute_path + yielded_so_far.add(resolved_absolute_path) + + def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]: + for path in self.directory_paths(base_paths): + for dirpath, dirnames, filenames in os.walk(path, followlinks=True): + if '.git' in dirnames: + dirnames.remove('.git') + + for filename in filenames: + if any(filename.endswith(ext) for ext in self.supported_extensions): + result_path = construct_path(dirpath) / filename + if relative: + yield result_path.relative_to(path) + else: + yield result_path + + def has_folder_name(self, folder_name: str) -> bool: + return folder_name in self.folder_names + + +@dataclasses.dataclass class FolderNames: + application_paths: typing.Optional[ApplicationPaths] = dataclasses.field(default_factory=ApplicationPaths) + contents: list[AbstractPaths] = dataclasses.field(default_factory=list) + base_paths: list[Path] = dataclasses.field(default_factory=list) + + def supported_extensions(self, folder_name: str) -> typing.Generator[str]: + for candidate in self.contents: + if candidate.has_folder_name(folder_name): + for supported_extensions in candidate.supported_extensions: + for supported_extension in supported_extensions: + yield supported_extension + + def directory_paths(self, folder_name: str) -> typing.Generator[Path]: + for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths) + for candidate in self.contents if candidate.has_folder_name(folder_name)]): + yield directory_path + + def file_paths(self, folder_name: str, relative=False) -> typing.Generator[Path]: + for file_path in itertools.chain.from_iterable([candidate.file_paths(self.base_paths, relative=relative) + for candidate in self.contents if candidate.has_folder_name(folder_name)]): + yield file_path + + def first_existing_or_none(self, folder_name: str, relative_file_path: Path) -> Optional[Path]: + for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths) + for candidate in self.contents if candidate.has_folder_name(folder_name)]): + candidate_file_path: Path = construct_path(directory_path / relative_file_path) + try: + # todo: this should follow the symlink + if Path.exists(candidate_file_path): + return candidate_file_path + except OSError: + continue + return None + + def add_supported_extension(self, folder_name: str, *supported_extensions: str | None): + if supported_extensions is None: + return + + for candidate in self.contents: + if candidate.has_folder_name(folder_name): + candidate.supported_extensions.update(supported_extensions) + + def remove_all_supported_extensions(self, folder_name: str): + for candidate in self.contents: + if candidate.has_folder_name(folder_name): + candidate.supported_extensions.clear() + + def add(self, model_paths: AbstractPaths): + self.contents.append(model_paths) + + def add_base_path(self, base_path: Path): + if base_path not in self.base_paths: + self.base_paths.append(base_path) + @staticmethod - def from_dict(folder_paths_dict: dict[str, tuple[typing.Sequence[str], Sequence[str]]] = None) -> FolderNames: + def from_dict(folder_paths_dict: dict[str, tuple[typing.Iterable[str], Iterable[str]]] = None) -> FolderNames: """ Turns a dictionary of { @@ -50,65 +288,179 @@ class FolderNames: :param folder_paths_dict: A dictionary :return: A FolderNames object """ - if folder_paths_dict is None: - return FolderNames(os.getcwd()) + raise NotImplementedError() - fn = FolderNames(os.getcwd()) - for folder_name, (paths, extensions) in folder_paths_dict.items(): - paths_tuple = FolderPathsTuple(folder_name=folder_name, paths=list(paths), supported_extensions=set(extensions)) - - if folder_name in fn: - fn[folder_name] += paths_tuple - else: - fn[folder_name] = paths_tuple - return fn - - def __init__(self, default_new_folder_path: str): - self.contents: Dict[str, FolderPathsTuple] = dict() - self.default_new_folder_path = default_new_folder_path - - def __getitem__(self, item) -> FolderPathsTuple: - if not isinstance(item, str): + def __getitem__(self, folder_name) -> FolderPathsTuple: + if not isinstance(folder_name, str) or folder_name is None: raise RuntimeError("expected folder path") - if item not in self.contents: - default_path = os.path.join(self.default_new_folder_path, item) - os.makedirs(default_path, exist_ok=True) - self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set()) - return self.contents[item] + # todo: it is probably never the intention to do this + try: + path = Path(folder_name) + if path.is_absolute(): + folder_name = path.stem + except Exception: + pass + if not any(candidate.has_folder_name(folder_name) for candidate in self.contents): + self.add(ModelPaths(folder_names=[folder_name], folder_name_base_path_subdir=construct_path(), supported_extensions=set(), folder_names_are_relative_directory_paths_too=False)) + return FolderPathsTuple(folder_name, parent=weakref.ref(self)) - def __setitem__(self, key: str, value: FolderPathsTuple): - assert isinstance(key, str) - if isinstance(value, tuple): + def add_paths(self, folder_name: str, paths: list[Path | str], index: Optional[int] = None): + """ + Adds, but does not create, new model paths + :param folder_name: + :param paths: + :param index: + :return: + """ + for candidate in self.contents: + if candidate.has_folder_name(folder_name): + self._modify_model_paths(folder_name, paths, set(), candidate, index=index) + + def _modify_model_paths(self, key: str, paths: Iterable[Path | str], supported_extensions: set[str], model_paths: AbstractPaths = None, index: Optional[int] = None) -> AbstractPaths: + model_paths = model_paths or ModelPaths([key], + supported_extensions=set(supported_extensions), + folder_names_are_relative_directory_paths_too=False) + if index is not None and index != 0: + raise ValueError(f"index was {index} but only 0 or None is supported") + + for path in paths: + if isinstance(path, str): + path = construct_path(path) + # when given absolute paths, try to formulate them as relative paths anyway + if path.is_absolute(): + for base_path in self.base_paths: + try: + relative_to_basepath = path.relative_to(base_path) + potential_folder_name = relative_to_basepath.stem + potential_subdir = relative_to_basepath.parent + + # does the folder_name so far match the key? + # or have we not seen this folder before? + folder_name_not_already_in_contents = all(not candidate.has_folder_name(potential_folder_name) for candidate in self.contents) + if potential_folder_name == key or folder_name_not_already_in_contents: + # fix the subdir + model_paths.folder_name_base_path_subdir = potential_subdir + model_paths.folder_names_are_relative_directory_paths_too = True + if folder_name_not_already_in_contents: + do_add(model_paths.folder_names, index, potential_folder_name) + else: + # if the folder name doesn't match the key, check if we have ever seen a folder name that matches the key: + if model_paths.folder_names_are_relative_directory_paths_too: + if potential_subdir == model_paths.folder_name_base_path_subdir: + # then we want to add this to the folder name, because it's probably compatibility + do_add(model_paths.folder_names, index, potential_folder_name) + else: + # not this case + model_paths.folder_names_are_relative_directory_paths_too = False + + if not model_paths.folder_names_are_relative_directory_paths_too: + model_paths.additional_relative_directory_paths.add(relative_to_basepath) + for resolve_folder_name in model_paths.folder_names: + model_paths.additional_relative_directory_paths.add(model_paths.folder_name_base_path_subdir / resolve_folder_name) + + # since this was an absolute path that was a subdirectory of one of the base paths, + # we are done + break + except ValueError: + # this is not a subpath of the base path + pass + + # if we got this far, none of the absolute paths were subdirectories of any base paths + # add it to our absolute paths + model_paths.additional_absolute_directory_paths.add(path) + else: + # since this is a relative path, peacefully add it to model_paths + potential_folder_name = path.stem + + try: + relative_to_current_subdir = path.relative_to(model_paths.folder_name_base_path_subdir) + + # if relative to the current subdir, we observe only one part, we're good to go + if len(relative_to_current_subdir.parts) == 1: + if potential_folder_name == key: + model_paths.folder_names_are_relative_directory_paths_too = True + else: + # if there already exists a folder_name by this name, do not add it, and switch to all relative paths + if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents): + model_paths.folder_names_are_relative_directory_paths_too = False + model_paths.additional_relative_directory_paths.add(path) + model_paths.folder_name_base_path_subdir = construct_path() + else: + do_add(model_paths.folder_names, index, potential_folder_name) + except ValueError: + # this means that the relative directory didn't contain the subdir so far + # something_not_models/key + if potential_folder_name == key: + model_paths.folder_name_base_path_subdir = path.parent + model_paths.folder_names_are_relative_directory_paths_too = True + else: + if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents): + model_paths.folder_names_are_relative_directory_paths_too = False + model_paths.additional_relative_directory_paths.add(path) + model_paths.folder_name_base_path_subdir = construct_path() + else: + do_add(model_paths.folder_names, index, potential_folder_name) + return model_paths + + def __setitem__(self, key: str, value: tuple | FolderPathsTuple | AbstractPaths): + # remove all existing paths for this key + self.__delitem__(key) + + if isinstance(value, AbstractPaths): + self.contents.append(value) + elif isinstance(value, (tuple, FolderPathsTuple)): paths, supported_extensions = value - value = FolderPathsTuple(key, paths, supported_extensions) - self.contents[key] = value + # typical cases: + # key="checkpoints", paths="C:/base_path/models/checkpoints" + # key="unets", paths="C:/base_path/models/unets", "C:/base_path/models/diffusion_models" + # ^ in this case, we will want folder_names to be both unets and diffusion_models + # key="custom_loader", paths="C:/base_path/models/checkpoints" + + # if the paths are subdirectories of any basepath, use relative paths + paths: list[Path] = list(map(Path, paths)) + self.contents.append(self._modify_model_paths(key, paths, supported_extensions)) def __len__(self): return len(self.contents) - def __iter__(self): - return iter(self.contents) + def __iter__(self) -> typing.Generator[str]: + for model_paths in self.contents: + for folder_name in model_paths.folder_names: + yield folder_name def __delitem__(self, key): - del self.contents[key] + to_remove: list[AbstractPaths] = [] + for model_paths in self.contents: + if model_paths.has_folder_name(key): + to_remove.append(model_paths) - def __contains__(self, item): - return item in self.contents + for model_paths in to_remove: + self.contents.remove(model_paths) + + def __contains__(self, item: str): + return any(candidate.has_folder_name(item) for candidate in self.contents) def items(self): - return self.contents.items() + items_view = { + folder_name: self[folder_name] for folder_name in self.keys() + } + return items_view.items() def values(self): - return self.contents.values() + return [self[folder_name] for folder_name in self.keys()] def keys(self): - return self.contents.keys() + return [x for x in self] def get(self, key, __default=None): - return self.contents.get(key, __default) + for candidate in self.contents: + if candidate.has_folder_name(key): + return FolderPathsTuple(key, parent=weakref.ref(self)) + if __default is not None: + raise ValueError("get with default is not supported") -class SaveImagePathResponse(NamedTuple): +class SaveImagePathTuple(NamedTuple): full_output_folder: str filename: str counter: int diff --git a/comfy/component_model/module_property.py b/comfy/component_model/module_property.py index b93f2045a..2673a3f16 100644 --- a/comfy/component_model/module_property.py +++ b/comfy/component_model/module_property.py @@ -1,22 +1,44 @@ import sys +from functools import wraps +def create_module_properties(): + properties = {} + patched_modules = set() -def module_property(func): - """Decorator to turn module functions into properties. - Function names must be prefixed with an underscore.""" - module = sys.modules[func.__module__] + def patch_module(module): + if module in patched_modules: + return - def base_getattr(name): - raise AttributeError( - f"module '{module.__name__}' has no attribute '{name}'") + def base_getattr(name): + raise AttributeError(f"module '{module.__name__}' has no attribute '{name}'") - old_getattr = getattr(module, '__getattr__', base_getattr) + old_getattr = getattr(module, '__getattr__', base_getattr) - def new_getattr(name): - if f'_{name}' == func.__name__: - return func() - else: - return old_getattr(name) + def new_getattr(name): + if name in properties: + return properties[name]() + else: + return old_getattr(name) - module.__getattr__ = new_getattr - return func + module.__getattr__ = new_getattr + patched_modules.add(module) + + class ModuleProperties: + @staticmethod + def getter(func): + @wraps(func) + def wrapper(): + return func() + + name = func.__name__ + if name.startswith('_'): + properties[name[1:]] = wrapper + else: + raise ValueError("Property function names must start with an underscore") + + module = sys.modules[func.__module__] + patch_module(module) + + return wrapper + + return ModuleProperties() \ No newline at end of file diff --git a/comfy/component_model/outputs_types.py b/comfy/component_model/outputs_types.py new file mode 100644 index 000000000..3d406ed01 --- /dev/null +++ b/comfy/component_model/outputs_types.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +import typing + +# for nodes that return: +# { "ui" : { "some_field": Any, "other": Any }} +# the outputs dict will be +# (the node id) +# { "1": { "some_field": Any, "other": Any }} +OutputsDict = dict[str, dict[str, typing.Any]] diff --git a/comfy/component_model/platform_path.py b/comfy/component_model/platform_path.py new file mode 100644 index 000000000..d8500e7d1 --- /dev/null +++ b/comfy/component_model/platform_path.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from pathlib import PurePosixPath, Path, PosixPath + + +def construct_path(*args) -> PurePosixPath | Path: + if len(args) > 0 and args[0] is not None and isinstance(args[0], str) and args[0].startswith("/"): + try: + return PosixPath(*args) + except Exception: + return PurePosixPath(*args) + else: + return Path(*args) diff --git a/comfy/component_model/queue_types.py b/comfy/component_model/queue_types.py index 59476e99e..05a0432d3 100644 --- a/comfy/component_model/queue_types.py +++ b/comfy/component_model/queue_types.py @@ -8,13 +8,15 @@ from typing import Tuple from typing_extensions import NotRequired, TypedDict +from .outputs_types import OutputsDict + QueueTuple = Tuple[float, str, dict, dict, list] MAXIMUM_HISTORY_SIZE = 10000 class TaskInvocation(NamedTuple): item_id: int | str - outputs: dict + outputs: OutputsDict status: Optional[ExecutionStatus] diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index 79086ac16..88e7b0c6a 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -10,11 +10,12 @@ from aio_pika.patterns import JsonRPC from aiohttp import web from aiormq import AMQPConnectionError +from .executors import ContextVarExecutor from .distributed_progress import DistributedExecutorToClientProgress from .distributed_types import RpcRequest, RpcReply +from .process_pool_executor import ProcessPoolExecutor from ..client.embedded_comfy_client import EmbeddedComfyClient from ..cmd.main_pre import tracer -from ..component_model.executor_types import Executor from ..component_model.queue_types import ExecutionStatus @@ -28,7 +29,7 @@ class DistributedPromptWorker: queue_name: str = "comfyui", health_check_port: int = 9090, loop: Optional[AbstractEventLoop] = None, - executor: Optional[Executor] = None): + executor: Optional[ContextVarExecutor | ProcessPoolExecutor] = None): self._rpc = None self._channel = None self._exit_stack = AsyncExitStack() diff --git a/comfy/distributed/executors.py b/comfy/distributed/executors.py new file mode 100644 index 000000000..626489731 --- /dev/null +++ b/comfy/distributed/executors.py @@ -0,0 +1,24 @@ +import concurrent +import contextvars +import typing +from concurrent.futures import Future, ThreadPoolExecutor +from functools import partial + +__version__ = '0.0.1' + +from .process_pool_executor import ProcessPoolExecutor + + +class ContextVarExecutor(ThreadPoolExecutor): + + def submit(self, fn: typing.Callable, *args, **kwargs) -> Future: + ctx = contextvars.copy_context() # type: contextvars.Context + + return super().submit(partial(ctx.run, partial(fn, *args, **kwargs))) + + +class ContextVarProcessPoolExecutor(ProcessPoolExecutor): + + def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future: + # TODO: serialize the "comfyui_execution_context" + pass diff --git a/comfy/execution_context.py b/comfy/execution_context.py index bec5c5b27..e63bf6e6a 100644 --- a/comfy/execution_context.py +++ b/comfy/execution_context.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, replace from typing import Optional, Final from .component_model.executor_types import ExecutorToClientProgress +from .component_model.folder_path_types import FolderNames from .distributed.server_stub import ServerStub _current_context: Final[ContextVar] = ContextVar("comfyui_execution_context") @@ -14,23 +15,21 @@ _current_context: Final[ContextVar] = ContextVar("comfyui_execution_context") @dataclass(frozen=True) class ExecutionContext: server: ExecutorToClientProgress + folder_names_and_paths: FolderNames node_id: Optional[str] = None task_id: Optional[str] = None inference_mode: bool = True -_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub()) +_current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames())) def current_execution_context() -> ExecutionContext: - try: - return _current_context.get() - except LookupError: - return _empty_execution_context + return _current_context.get() @contextmanager -def new_execution_context(ctx: ExecutionContext): +def _new_execution_context(ctx: ExecutionContext): token = _current_context.set(ctx) try: yield ctx @@ -39,8 +38,24 @@ def new_execution_context(ctx: ExecutionContext): @contextmanager -def context_execute_node(node_id: str, prompt_id: str): +def context_folder_names_and_paths(folder_names_and_paths: FolderNames): current_ctx = current_execution_context() - new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id) - with new_execution_context(new_ctx): + new_ctx = replace(current_ctx, folder_names_and_paths=folder_names_and_paths) + with _new_execution_context(new_ctx): + yield new_ctx + + +@contextmanager +def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True): + current_ctx = current_execution_context() + new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode) + with _new_execution_context(new_ctx): + yield new_ctx + + +@contextmanager +def context_execute_node(node_id: str): + current_ctx = current_execution_context() + new_ctx = replace(current_ctx, node_id=node_id) + with _new_execution_context(new_ctx): yield new_ctx diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index ac9e769b3..c75a3c7d5 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -198,10 +198,6 @@ Visit the repository, accept the terms, and then do one of the following: - Login to Hugging Face in your terminal using `huggingface-cli login` """) raise exc_info - finally: - # a path was found for any reason, so we should invalidate the cache - if path is not None: - folder_paths.invalidate_cache(folder_name) if path is None: raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.") return path @@ -504,7 +500,6 @@ def add_known_models(folder_name: str, known_models: Optional[List[Downloadable] pre_existing = frozenset(known_models) known_models.extend([model for model in models if model not in pre_existing]) - folder_paths.invalidate_cache(folder_name) return known_models diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index bdb00ae28..8ba4a6faa 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -2,7 +2,6 @@ import torch import os import json -import hashlib import math import random import logging @@ -1654,7 +1653,7 @@ class LoadImageMask: input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {"required": - {"image": (sorted(files), {"image_upload": True}), + {"image": (natsorted(files), {"image_upload": True}), "channel": (s._color_channels, ), } } diff --git a/comfy/utils.py b/comfy/utils.py index c39343d8b..f41e01ce0 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,9 +29,9 @@ import sys import warnings from contextlib import contextmanager from pathlib import Path +from pickle import UnpicklingError from typing import Optional, Any -import accelerate import numpy as np import safetensors.torch import torch @@ -65,7 +65,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None): if ckpt is None: raise FileNotFoundError("the checkpoint was not found") if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): - sd = safetensors.torch.load_file(ckpt, device=device.type) + sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type) elif ckpt.lower().endswith("index.json"): # from accelerate index_filename = ckpt @@ -81,20 +81,29 @@ def load_torch_file(ckpt: str, safe_load=False, device=None): for checkpoint_file in checkpoint_files: sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type)) else: - if safe_load: - if not 'weights_only' in torch.load.__code__.co_varnames: - logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") - safe_load = False - if safe_load: - pl_sd = torch.load(ckpt, map_location=device, weights_only=True) - else: - pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) - if "global_step" in pl_sd: - logging.debug(f"Global Step: {pl_sd['global_step']}") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd + try: + if safe_load: + if not 'weights_only' in torch.load.__code__.co_varnames: + logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + safe_load = False + if safe_load: + pl_sd = torch.load(ckpt, map_location=device, weights_only=True) + else: + pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) + if "global_step" in pl_sd: + logging.debug(f"Global Step: {pl_sd['global_step']}") + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd + except UnpicklingError as exc_info: + try: + # wrong extension is most likely, try to load as safetensors anyway + sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type) + return sd + except Exception: + exc_info.add_note(f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again") + raise exc_info return sd diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py index 12457256a..b81d0de29 100644 --- a/comfy_extras/nodes/nodes_language.py +++ b/comfy_extras/nodes/nodes_language.py @@ -13,7 +13,7 @@ from transformers.models.nllb.tokenization_nllb import \ FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES from comfy.cmd import folder_paths -from comfy.component_model.folder_path_types import SaveImagePathResponse +from comfy.component_model.folder_path_types import SaveImagePathTuple from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \ TOKENS_TYPE_NAME, LanguageModel @@ -397,7 +397,7 @@ class SaveString(CustomNode): OUTPUT_NODE = True RETURN_TYPES = () - def get_save_path(self, filename_prefix) -> SaveImagePathResponse: + def get_save_path(self, filename_prefix) -> SaveImagePathTuple: return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0) def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"): diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py deleted file mode 100644 index 1d914fa93..000000000 --- a/comfy_extras/nodes_torch_compile.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -class TorchCompileModel: - @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "_for_testing" - EXPERIMENTAL = True - - def patch(self, model): - m = model.clone() - m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"))) - return (m, ) - -NODE_CLASS_MAPPINGS = { - "TorchCompileModel": TorchCompileModel, -} diff --git a/main.py b/main.py index c1ef41468..002c8926d 100644 --- a/main.py +++ b/main.py @@ -2,13 +2,15 @@ import asyncio import warnings from pathlib import Path -if __name__ == "__main__": - from comfy.cmd.folder_paths_pre import set_base_path +from comfy.component_model.folder_path_types import FolderNames +if __name__ == "__main__": warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning) this_file_parent_dir = Path(__file__).parent - set_base_path(str(this_file_parent_dir)) - from comfy.cmd.main import main + from comfy.cmd.folder_paths import folder_names_and_paths # type: FolderNames + fn: FolderNames = folder_names_and_paths + fn.base_paths.clear() + fn.base_paths.append(this_file_parent_dir) asyncio.run(main(from_script_dir=this_file_parent_dir)) diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 73437fe29..511967189 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -1,6 +1,8 @@ import asyncio +import logging +logging.basicConfig(level=logging.ERROR) + import uuid -from concurrent.futures import ThreadPoolExecutor from typing import Callable import jwt @@ -15,10 +17,12 @@ from comfy.component_model.executor_types import Executor from comfy.component_model.make_mutable import make_mutable from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker +from comfy.distributed.executors import ContextVarExecutor from comfy.distributed.process_pool_executor import ProcessPoolExecutor from comfy.distributed.server_stub import ServerStub + def create_test_prompt() -> QueueItem: from comfy.cmd.execution import validate_prompt @@ -37,8 +41,11 @@ async def test_sign_jwt_auth_none(): assert user_token["sub"] == client_id +_executor_factories: tuple[Executor] = (ContextVarExecutor,) + + @pytest.mark.asyncio -@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,)) +@pytest.mark.parametrize("executor_factory", _executor_factories) async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None: with RabbitMqContainer("rabbitmq:latest") as rabbitmq: params = rabbitmq.get_connection_params() @@ -72,7 +79,7 @@ async def test_distributed_prompt_queues_same_process(): frontend.put(test_prompt) # start a worker thread - thread_pool = ThreadPoolExecutor(max_workers=1) + thread_pool = ContextVarExecutor(max_workers=1) async def in_thread(): incoming, incoming_prompt_id = worker.get() @@ -127,7 +134,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0) @pytest.mark.asyncio -@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,)) +@pytest.mark.parametrize("executor_factory", _executor_factories) async def test_basic_queue_worker_with_health_check(executor_factory): with RabbitMqContainer("rabbitmq:latest") as rabbitmq: params = rabbitmq.get_connection_params() diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index a04b2c55c..ca975a7a0 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -1,3 +1,4 @@ +import logging import uuid from contextvars import ContextVar from typing import Dict, Optional @@ -73,7 +74,8 @@ class ComfyClient: prompt_id = str(uuid.uuid4()) try: outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id) - except (RuntimeError, DependencyCycleError): + except (RuntimeError, DependencyCycleError) as exc_info: + logging.warning("error when queueing prompt", exc_info=exc_info) outputs = {} result = RunResult(prompt_id=prompt_id) result.outputs = outputs diff --git a/tests/unit/comfy_test/folder_path_test.py b/tests/unit/comfy_test/folder_path_test.py index 82c2930e8..51f6bc31c 100644 --- a/tests/unit/comfy_test/folder_path_test.py +++ b/tests/unit/comfy_test/folder_path_test.py @@ -2,11 +2,13 @@ # TODO(yoland): clean up this after I get back down import os import tempfile -from unittest.mock import patch +from pathlib import Path import pytest from comfy.cmd import folder_paths +from comfy.component_model.folder_path_types import FolderNames, ModelPaths +from comfy.execution_context import context_folder_names_and_paths @pytest.fixture @@ -40,16 +42,6 @@ def test_add_model_folder_path(): assert "/test/path" in folder_paths.get_folder_paths("test_folder") -def test_recursive_search(temp_dir): - os.makedirs(os.path.join(temp_dir, "subdir")) - open(os.path.join(temp_dir, "file1.txt"), "w").close() - open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close() - - files, dirs = folder_paths.recursive_search(temp_dir) - assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")} - assert len(dirs) == 2 # temp_dir and subdir - - def test_filter_files_extensions(): files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"] assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"] @@ -57,16 +49,24 @@ def test_filter_files_extensions(): assert folder_paths.filter_files_extensions(files, []) == files -@patch("folder_paths.recursive_search") -@patch("folder_paths.folder_names_and_paths") -def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search): - mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"}) - mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {}) - assert folder_paths.get_filename_list("test_folder") == ["file1.txt"] +def test_get_filename_list(temp_dir): + base_path = Path(temp_dir) + fn = FolderNames(base_paths=[base_path]) + rel_path = Path("test/path") + fn.add(ModelPaths(["test_folder"], additional_relative_directory_paths={rel_path}, supported_extensions={".txt"})) + dir_path = base_path / rel_path + Path.mkdir(dir_path, parents=True, exist_ok=True) + files = ["file1.txt", "file2.jpg"] + + for file in files: + Path.touch(dir_path / file, exist_ok=True) + + with context_folder_names_and_paths(fn): + assert folder_paths.get_filename_list("test_folder") == ["file1.txt"] def test_get_save_image_path(temp_dir): - with patch("folder_paths.output_directory", temp_dir): + with context_folder_names_and_paths(FolderNames(base_paths=[Path(temp_dir)])): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100) assert os.path.samefile(full_output_folder, temp_dir) assert filename == "test"