The --base-paths argument adds additional paths to search for models/checkpoints, models/loras and similar directories, including directories that custom nodes register in this pattern.

doctorpangloss 2024-10-28 19:04:14 -07:00
parent 89d07f3adf
commit 4a13766d14
30 changed files with 932 additions and 525 deletions
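
A minimal sketch of how the new flag combines with -w/--cwd, following the init_default_paths logic in this diff; the directories shown are illustrative:

```python
from pathlib import Path

def resolve_base_paths(cwd: str | None, base_paths: list[str]) -> list[Path]:
    # --cwd (when set) is always the first base path; --base-paths entries follow
    paths = [Path(cwd)] if cwd is not None else []
    paths += [Path(p) for p in base_paths]
    # with neither flag given, fall back to the process working directory
    return paths or [Path.cwd()]

# comfyui -w /srv/comfyui --base-paths /mnt/nas/models /opt/shared
print(resolve_base_paths("/srv/comfyui", ["/mnt/nas/models", "/opt/shared"]))
```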

README.md

@@ -722,19 +722,31 @@ You can pass additional extra model path configurations with one or more copies
 ### Command Line Arguments
 ```
-usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [-H [IP]] [--port PORT] [--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE]
-                   [--extra-model-paths-config PATH [PATH ...]] [--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch]
-                   [--disable-auto-launch] [--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16]
+usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [--base-paths BASE_PATHS [BASE_PATHS ...]] [-H [IP]] [--port PORT]
+                   [--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE] [--extra-model-paths-config PATH [PATH ...]]
+                   [--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch] [--disable-auto-launch]
+                   [--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16]
                    [--bf16-unet | --fp16-unet | --fp8_e4m3fn-unet | --fp8_e5m2-unet] [--fp16-vae | --fp32-vae | --bf16-vae] [--cpu-vae]
                    [--fp8_e4m3fn-text-enc | --fp8_e5m2-text-enc | --fp16-text-enc | --fp32-text-enc] [--directml [DIRECTML_DEVICE]] [--disable-ipex-optimize]
-                   [--preview-method [none,auto,latent2rgb,taesd]] [--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers]
-                   [--force-upcast-attention | --dont-upcast-attention] [--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--disable-smart-memory] [--deterministic]
-                   [--dont-print-server] [--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--multi-user] [--create-directories]
-                   [--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN] [--analytics-use-identity-provider]
-                   [--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker] [--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME]
-                   [--external-address EXTERNAL_ADDRESS] [--verbose] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE] [--otel-service-name OTEL_SERVICE_NAME]
-                   [--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT]
+                   [--preview-method [none,auto,latent2rgb,taesd]] [--preview-size PREVIEW_SIZE] [--cache-lru CACHE_LRU]
+                   [--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers] [--disable-flash-attn]
+                   [--disable-sage-attention] [--force-upcast-attention | --dont-upcast-attention]
+                   [--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--reserve-vram RESERVE_VRAM]
+                   [--default-hashing-function {md5,sha1,sha256,sha512}] [--disable-smart-memory] [--deterministic] [--fast] [--dont-print-server]
+                   [--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--disable-all-custom-nodes] [--multi-user] [--create-directories]
+                   [--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN]
+                   [--analytics-use-identity-provider] [--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker]
+                   [--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME] [--external-address EXTERNAL_ADDRESS]
+                   [--logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE]
+                   [--otel-service-name OTEL_SERVICE_NAME] [--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT]
+                   [--force-channels-last] [--force-hf-local-dir-mode] [--front-end-version FRONT_END_VERSION] [--front-end-root FRONT_END_ROOT]
+                   [--executor-factory EXECUTOR_FACTORY] [--openai-api-key OPENAI_API_KEY] [--user-directory USER_DIRECTORY] [--blip-model-url BLIP_MODEL_URL]
+                   [--blip-model-vqa-url BLIP_MODEL_VQA_URL] [--sam-model-vith-url SAM_MODEL_VITH_URL] [--sam-model-vitl-url SAM_MODEL_VITL_URL]
+                   [--sam-model-vitb-url SAM_MODEL_VITB_URL] [--history-display-limit HISTORY_DISPLAY_LIMIT] [--ffmpeg-bin-path FFMPEG_BIN_PATH]
+                   [--ffmpeg-extra-codecs FFMPEG_EXTRA_CODECS] [--wildcards-path WILDCARDS_PATH] [--wildcard-api WILDCARD_API] [--photoprism-host PHOTOPRISM_HOST]
+                   [--immich-host IMMICH_HOST] [--ideogram-session-cookie IDEOGRAM_SESSION_COOKIE] [--annotator-ckpts-path ANNOTATOR_CKPTS_PATH] [--use-symlinks]
+                   [--ort-providers ORT_PROVIDERS] [--vfi-ops-backend VFI_OPS_BACKEND] [--dependency-version DEPENDENCY_VERSION] [--mmdet-skip] [--sam-editor-cpu]
+                   [--sam-editor-model SAM_EDITOR_MODEL] [--custom-wildcards CUSTOM_WILDCARDS] [--disable-gpu-opencv]

 options:
   -h, --help            show this help message and exit
@@ -742,10 +754,14 @@ options:
                         config file path
   --write-out-config-file CONFIG_OUTPUT_PATH
                         takes the current command line args and writes them out to a config file at the given path, then exits
   -w CWD, --cwd CWD     Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default. [env var: COMFYUI_CWD]
+  --base-paths BASE_PATHS [BASE_PATHS ...]
+                        Additional base paths for custom nodes, models and inputs. [env var: COMFYUI_BASE_PATHS]
   -H [IP], --listen [IP]
-                        Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all) [env var: COMFYUI_LISTEN]
+                        Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6) [env var: COMFYUI_LISTEN]
   --port PORT           Set the listen port. [env var: COMFYUI_PORT]
   --enable-cors-header [ORIGIN]
                         Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'. [env var: COMFYUI_ENABLE_CORS_HEADER]
@@ -789,6 +805,10 @@ options:
                        Disables ipex.optimize when loading models with Intel GPUs. [env var: COMFYUI_DISABLE_IPEX_OPTIMIZE]
   --preview-method [none,auto,latent2rgb,taesd]
                         Default preview method for sampler nodes. [env var: COMFYUI_PREVIEW_METHOD]
+  --preview-size PREVIEW_SIZE
+                        Sets the maximum preview size for sampler nodes. [env var: COMFYUI_PREVIEW_SIZE]
+  --cache-lru CACHE_LRU
+                        Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM. [env var: COMFYUI_CACHE_LRU]
   --use-split-cross-attention
                         Use the split cross attention optimization. Ignored when xformers is used. [env var: COMFYUI_USE_SPLIT_CROSS_ATTENTION]
   --use-quad-cross-attention
@@ -796,6 +816,9 @@ options:
   --use-pytorch-cross-attention
                         Use the new pytorch 2.0 cross attention function. [env var: COMFYUI_USE_PYTORCH_CROSS_ATTENTION]
   --disable-xformers    Disable xformers. [env var: COMFYUI_DISABLE_XFORMERS]
+  --disable-flash-attn  Disable Flash Attention [env var: COMFYUI_DISABLE_FLASH_ATTN]
+  --disable-sage-attention
+                        Disable Sage Attention [env var: COMFYUI_DISABLE_SAGE_ATTENTION]
   --force-upcast-attention
                         Force enable attention upcasting, please report if it fixes black images. [env var: COMFYUI_FORCE_UPCAST_ATTENTION]
   --dont-upcast-attention
@@ -806,15 +829,25 @@ options:
   --lowvram             Split the unet in parts to use less vram. [env var: COMFYUI_LOWVRAM]
   --novram              When lowvram isn't enough. [env var: COMFYUI_NOVRAM]
   --cpu                 To use the CPU for everything (slow). [env var: COMFYUI_CPU]
+  --reserve-vram RESERVE_VRAM
+                        Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS. [env var: COMFYUI_RESERVE_VRAM]
+  --default-hashing-function {md5,sha1,sha256,sha512}
+                        Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256. [env var: COMFYUI_DEFAULT_HASHING_FUNCTION]
   --disable-smart-memory
                         Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can. [env var: COMFYUI_DISABLE_SMART_MEMORY]
   --deterministic       Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var: COMFYUI_DETERMINISTIC]
+  --fast                Enable some untested and potentially quality deteriorating optimizations. [env var: COMFYUI_FAST]
   --dont-print-server   Don't print server output. [env var: COMFYUI_DONT_PRINT_SERVER]
   --quick-test-for-ci   Quick test for CI. Raises an error if nodes cannot be imported. [env var: COMFYUI_QUICK_TEST_FOR_CI]
   --windows-standalone-build
                         Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup). [env var: COMFYUI_WINDOWS_STANDALONE_BUILD]
   --disable-metadata    Disable saving prompt metadata in files. [env var: COMFYUI_DISABLE_METADATA]
+  --disable-all-custom-nodes
+                        Disable loading all custom nodes. [env var: COMFYUI_DISABLE_ALL_CUSTOM_NODES]
   --multi-user          Enables per-user storage. [env var: COMFYUI_MULTI_USER]
   --create-directories  Creates the default models/, input/, output/ and temp/ directories, then exits. [env var: COMFYUI_CREATE_DIRECTORIES]
   --plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL
@@ -824,18 +857,19 @@ options:
   --analytics-use-identity-provider
                         Uses platform identifiers for unique visitor analytics. [env var: COMFYUI_ANALYTICS_USE_IDENTITY_PROVIDER]
   --distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI
                         EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMQP URL to form a distributed queue and exchange prompt execution requests and progress updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI]
   --distributed-queue-worker
                         Workers will pull requests off the AMQP URL. [env var: COMFYUI_DISTRIBUTED_QUEUE_WORKER]
   --distributed-queue-frontend
                         Frontends will start the web UI and connect to the provided AMQP URL to submit prompts. [env var: COMFYUI_DISTRIBUTED_QUEUE_FRONTEND]
   --distributed-queue-name DISTRIBUTED_QUEUE_NAME
                         This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME]
   --external-address EXTERNAL_ADDRESS
                         Specifies a base URL for external addresses reported by the API, such as for image paths. [env var: COMFYUI_EXTERNAL_ADDRESS]
-  --verbose             Enables more debug prints. [env var: COMFYUI_VERBOSE]
+  --logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
+                        Set the logging level [env var: COMFYUI_LOGGING_LEVEL]
   --disable-known-models
                         Disables automatic downloads of known models and prevents them from appearing in the UI. [env var: COMFYUI_DISABLE_KNOWN_MODELS]
   --max-queue-size MAX_QUEUE_SIZE
@@ -845,8 +879,27 @@ options:
   --otel-service-version OTEL_SERVICE_VERSION
                         The version of the service or application that is generating telemetry data. [env var: OTEL_SERVICE_VERSION]
   --otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT
                         A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT]
+  --force-channels-last
+                        Force channels last format when inferencing the models. [env var: COMFYUI_FORCE_CHANNELS_LAST]
+  --force-hf-local-dir-mode
+                        Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure. [env var: COMFYUI_FORCE_HF_LOCAL_DIR_MODE]
+  --front-end-version FRONT_END_VERSION
+                        Specifies the version of the frontend to be used. This command needs internet connectivity to query and download available frontend implementations from GitHub releases. The version string should be in the format of: [repoOwner]/[repoName]@[version] where version is one of: "latest" or a valid version number (e.g. "1.0.0") [env var: COMFYUI_FRONT_END_VERSION]
+  --front-end-root FRONT_END_ROOT
+                        The local filesystem path to the directory where the frontend is located. Overrides --front-end-version. [env var: COMFYUI_FRONT_END_ROOT]
+  --executor-factory EXECUTOR_FACTORY
+                        When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and large, contiguous blocks of CUDA memory can be freed. [env var: COMFYUI_EXECUTOR_FACTORY]
+  --openai-api-key OPENAI_API_KEY
+                        Configures the OpenAI API Key for the OpenAI nodes [env var: OPENAI_API_KEY]
+  --user-directory USER_DIRECTORY
+                        Set the ComfyUI user directory with an absolute path. [env var: COMFYUI_USER_DIRECTORY]

 Args that start with '--' can also be set in a config file (config.yaml or config.json or specified via -c). Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see syntax at https://goo.gl/R74nmi). In general, command-line values override environment variables which override config file values which override defaults.
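
A minimal sketch of that precedence rule with configargparse, the library behind this parser; the flag and file names are illustrative:

```python
import configargparse

# command line beats environment variable beats config file beats coded default
parser = configargparse.ArgParser(default_config_files=["config.yaml"])
parser.add_argument("-c", "--config", is_config_file=True, help="config file path")
parser.add_argument("--port", type=int, default=8188, env_var="COMFYUI_PORT")

options = parser.parse_args(["--port", "9000"])
assert options.port == 9000  # the explicit CLI value wins
```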


@@ -28,6 +28,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument('-w', "--cwd", type=str, default=None,
                         help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
+    parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.")
     parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
                         help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
     parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
@@ -161,7 +162,7 @@ def _create_parser() -> EnhancedConfigArgParser:
                         help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
     parser.add_argument("--external-address", required=False,
                         help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
-    parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
+    parser.add_argument("--logging-level", type=str, default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
     parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.")
     parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.")
     # tracing
@@ -251,11 +252,6 @@ def _parse_args(parser: Optional[argparse.ArgumentParser] = None, args_parsing:
     if args.disable_auto_launch:
         args.auto_launch = False

-    logging_level = logging.INFO
-    if args.verbose:
-        logging_level = logging.DEBUG
-    logging.basicConfig(format="%(message)s", level=logging_level)
-
     configuration_obj = Configuration(**vars(args))
     configuration_obj.config_files = config_files
     assert all(isinstance(config_file, str) for config_file in config_files)
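
With logging.basicConfig removed from _parse_args above, callers are expected to configure logging from the parsed level themselves; a hedged sketch of the equivalent setup:

```python
import logging

logging_level = "DEBUG"  # stands in for the parsed args.logging_level
logging.basicConfig(format="%(message)s", level=getattr(logging, logging_level))
logging.debug("visible because the level is DEBUG")
```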


@@ -1,6 +1,8 @@
 from __future__ import annotations

+import collections
 import enum
+from pathlib import Path
 from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple

 import configargparse
@@ -36,15 +38,16 @@ class Configuration(dict):
     Attributes:
         config_files (Optional[List[str]]): Path to the configuration file(s) that were set in the arguments.
-        cwd (Optional[str]): Working directory. Defaults to the current directory.
+        cwd (Optional[str]): Working directory. Defaults to the current directory. This is always treated as a base path for model files, and it will be the place where model files are downloaded.
+        base_paths (Optional[list[str]]): Additional base paths for custom nodes, models and inputs.
         listen (str): IP address to listen on. Defaults to "127.0.0.1".
         port (int): Port number for the server to listen on. Defaults to 8188.
         enable_cors_header (Optional[str]): Enables CORS with the specified origin.
         max_upload_size (float): Maximum upload size in MB. Defaults to 100.
         extra_model_paths_config (Optional[List[str]]): Extra model paths configuration files.
-        output_directory (Optional[str]): Directory for output files.
+        output_directory (Optional[str]): Directory for output files. This can also be a relative path to the cwd or current working directory.
         temp_directory (Optional[str]): Temporary directory for processing.
-        input_directory (Optional[str]): Directory for input files.
+        input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths.
         auto_launch (bool): Auto-launch UI in the default browser. Defaults to False.
         disable_auto_launch (bool): Disable auto-launching the browser.
         cuda_device (Optional[int]): CUDA device ID. None means default device.
@@ -87,7 +90,6 @@ class Configuration(dict):
         reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS
         disable_smart_memory (bool): Disable smart memory management.
         deterministic (bool): Use deterministic algorithms where possible.
-        dont_print_server (bool): Suppress server output.
         quick_test_for_ci (bool): Enable quick testing mode for CI.
         windows_standalone_build (bool): Enable features for standalone Windows build.
         disable_metadata (bool): Disable saving metadata with outputs.
@@ -103,7 +105,7 @@ class Configuration(dict):
         distributed_queue_worker (bool): Workers will pull requests off the AMQP URL.
         distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
         external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
-        verbose (bool | str): Shows extra output for debugging purposes such as import errors of custom nodes; or, specifies a log level
+        logging_level (str): Specifies a log level
         disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI.
         max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value.
         otel_service_name (str): The name of the service or application that is generating telemetry data. Default: "comfyui".
@@ -122,6 +124,7 @@ class Configuration(dict):
         self._observers: List[ConfigObserver] = []
         self.config_files = []
         self.cwd: Optional[str] = None
+        self.base_paths: list[Path] = []
         self.listen: str = "127.0.0.1"
         self.port: int = 8188
         self.enable_cors_header: Optional[str] = None
@@ -192,7 +195,7 @@ class Configuration(dict):
         self.force_channels_last: bool = False
         self.force_hf_local_dir_mode = False
         self.preview_size: int = 512
-        self.verbose: str | bool = "INFO"
+        self.logging_level: str = "INFO"
         # from guill
         self.cache_lru: int = 0
@@ -253,6 +256,17 @@ class Configuration(dict):
         self.update(state)
         self._observers = []

+    @property
+    def verbose(self) -> str:
+        return self.logging_level
+
+    @verbose.setter
+    def verbose(self, value):
+        if isinstance(value, bool):
+            self.logging_level = "DEBUG"
+        else:
+            self.logging_level = value
+
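
The verbose property pair above keeps older callers working; a short sketch of the mapping, assuming Configuration can be constructed with its defaults as shown in __init__:

```python
cfg = Configuration()
cfg.verbose = True       # legacy boolean form maps to DEBUG
assert cfg.logging_level == "DEBUG"
cfg.verbose = "WARNING"  # legacy string form passes through
assert cfg.verbose == cfg.logging_level == "WARNING"
```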
 class EnumAction(argparse.Action):
     """


@@ -1,13 +1,14 @@
 from __future__ import annotations

 import asyncio
-import contextvars
 import gc
 import json
+import os
+import threading
 import uuid
 from asyncio import get_event_loop
-from concurrent.futures import ThreadPoolExecutor
 from multiprocessing import RLock
+from pathlib import Path
 from typing import Optional

 from opentelemetry import context, propagate
@@ -17,13 +18,16 @@ from opentelemetry.trace import Status, StatusCode
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
+from ..cmd.folder_paths import init_default_paths
 from ..cmd.main_pre import tracer
-from ..component_model.executor_types import ExecutorToClientProgress, Executor
+from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.make_mutable import make_mutable
+from ..distributed.executors import ContextVarExecutor
 from ..distributed.process_pool_executor import ProcessPoolExecutor
 from ..distributed.server_stub import ServerStub
+from ..execution_context import current_execution_context

-_prompt_executor = contextvars.ContextVar('prompt_executor')
+_prompt_executor = threading.local()


 def _execute_prompt(
@@ -33,6 +37,9 @@ def _execute_prompt(
         span_context: dict,
         progress_handler: ExecutorToClientProgress | None,
         configuration: Configuration | None) -> dict:
+    execution_context = current_execution_context()
+    if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
+        init_default_paths(execution_context.folder_names_and_paths, configuration)
     span_context: Context = propagate.extract(span_context)
     token = attach(span_context)
     try:
@@ -52,8 +59,8 @@ def __execute_prompt(
     progress_handler = progress_handler or ServerStub()

     try:
-        prompt_executor = _prompt_executor.get()
-    except LookupError:
+        prompt_executor: PromptExecutor = _prompt_executor.executor
+    except (LookupError, AttributeError):
         if configuration is None:
             options.enable_args_parsing()
         else:
@@ -65,7 +72,7 @@ def __execute_prompt(
         with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
             prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
             prompt_executor.raise_exceptions = True
-            _prompt_executor.set(prompt_executor)
+            _prompt_executor.executor = prompt_executor

     with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
         try:
@@ -96,6 +103,13 @@ def __execute_prompt(

 def _cleanup():
+    from ..cmd.execution import PromptExecutor
+    try:
+        prompt_executor: PromptExecutor = _prompt_executor.executor
+        # this should clear all references to output tensors and make it easier to collect back the memory
+        prompt_executor.reset()
+    except (LookupError, AttributeError):
+        pass
     from .. import model_management
     model_management.unload_all_models()
     gc.collect()
@@ -139,9 +153,9 @@ class EmbeddedComfyClient:
     In order to use this in blocking methods, learn more about asyncio online.
     """

-    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
+    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor = None):
         self._progress_handler = progress_handler or ServerStub()
-        self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
+        self._executor = executor or ContextVarExecutor(max_workers=max_workers)
         self._configuration = configuration
         self._is_running = False
         self._task_count_lock = RLock()
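
Swapping contextvars.ContextVar for threading.local() caches one PromptExecutor per worker thread regardless of which context a task runs under; a standalone sketch of the lookup pattern used above:

```python
import threading

_local = threading.local()

def get_or_create_executor():
    try:
        return _local.executor      # same thread, same cached instance
    except AttributeError:          # first call on this thread
        _local.executor = object()  # stand-in for PromptExecutor(...)
        return _local.executor
```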


@@ -27,7 +27,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
     HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
-from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
+from ..execution_context import context_execute_node, context_execute_prompt
 from ..nodes.package import import_all_nodes_in_workspace
 from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@@ -77,24 +77,19 @@ class IsChangedCache:
 class CacheSet:
     def __init__(self, lru_size=None):
         if lru_size is None or lru_size == 0:
-            self.init_classic_cache()
+            # Performs like the old cache -- dump data ASAP
+            self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+            self.ui = HierarchicalCache(CacheKeySetInputSignature)
+            self.objects = HierarchicalCache(CacheKeySetID)
         else:
-            self.init_lru_cache(lru_size)
+            # Useful for those with ample RAM/VRAM -- allows experimenting without
+            # blowing away the cache every time
+            self.outputs = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
+            self.ui = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
+            self.objects = HierarchicalCache(CacheKeySetID)
         self.all = [self.outputs, self.ui, self.objects]

-    # Useful for those with ample RAM/VRAM -- allows experimenting without
-    # blowing away the cache every time
-    def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.objects = HierarchicalCache(CacheKeySetID)
-
-    # Performs like the old cache -- dump data ASAP
-    def init_classic_cache(self):
-        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
-        self.objects = HierarchicalCache(CacheKeySetID)

     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
@@ -308,11 +303,11 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
     :param pending_subgraph_results:
     :return:
     """
-    with context_execute_node(_node_id, prompt_id):
+    with context_execute_node(_node_id):
         return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)


-def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
+def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -548,7 +543,7 @@ class PromptExecutor:
         # torchao and potentially other optimization approaches break when the models are created in inference mode
         # todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
         inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
-        with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
+        with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode):
             self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)

     def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
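
This is also where the new --cache-lru flag lands: the embedded client above constructs PromptExecutor with lru_size=configuration.cache_lru, so a nonzero value switches the outputs and ui caches to LRU. A hedged usage sketch; the comfy.cmd.execution import path follows this repo's relative imports and is an assumption:

```python
from comfy.cmd.execution import CacheSet

classic = CacheSet(lru_size=0)  # hierarchical caches, results dropped ASAP
lru = CacheSet(lru_size=128)    # keeps up to 128 node results for reuse
```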


@@ -4,142 +4,144 @@ import logging
 import mimetypes
 import os
 import time
-from typing import Optional, List, Final, Literal
+from contextlib import nullcontext
+from pathlib import Path
+from typing import Optional, List, Literal

-from .folder_paths_pre import get_base_path
+from ..cli_args_types import Configuration
+from ..component_model.deprecation import _deprecate_method
 from ..component_model.files import get_package_as_path
-from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
-from ..component_model.folder_path_types import extension_mimetypes_cache as _extension_mimetypes_cache
-from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
-from ..component_model.module_property import module_property
+from ..component_model.folder_path_types import FolderNames, SaveImagePathTuple, ModelPaths
+from ..component_model.folder_path_types import supported_pt_extensions, extension_mimetypes_cache
+from ..component_model.module_property import create_module_properties
+from ..component_model.platform_path import construct_path
+from ..execution_context import current_execution_context

-supported_pt_extensions: Final[frozenset[str]] = _supported_pt_extensions
-extension_mimetypes_cache: Final[dict[str, str]] = _extension_mimetypes_cache
+_module_properties = create_module_properties()
+
+
+@_module_properties.getter
+def _supported_pt_extensions() -> frozenset[str]:
+    return supported_pt_extensions
+
+
+@_module_properties.getter
+def _extension_mimetypes_cache() -> dict[str, str]:
+    return extension_mimetypes_cache


 # todo: this needs to be wrapped in a context and configurable
-@module_property
+@_module_properties.getter
 def _base_path():
-    return get_base_path()
+    return _folder_names_and_paths().base_paths[0]


-models_dir = os.path.join(get_base_path(), "models")
-folder_names_and_paths: Final[FolderNames] = FolderNames(models_dir)
-folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
-folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
-folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
-folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
-folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
-folder_names_and_paths["unet"] = folder_names_and_paths["diffusion_models"] = FolderPathsTuple("diffusion_models", [os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], set(supported_pt_extensions))
-folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
-folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
-folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))
-folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"})
-folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions))
-folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions))
-folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions))
-folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions))
-folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(get_base_path(), "custom_nodes")], set())
-folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
-folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
-folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
-folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""})
-folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""})
-
-output_directory = os.path.join(get_base_path(), "output")
-temp_directory = os.path.join(get_base_path(), "temp")
-input_directory = os.path.join(get_base_path(), "input")
-user_directory = os.path.join(get_base_path(), "user")
-
-filename_list_cache = {}
-
-
-class CacheHelper:
-    """
-    Helper class for managing file list cache data.
-    """
-
-    def __init__(self):
-        self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
-        self.active = False
-
-    def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
-        if not self.active:
-            return default
-        return self.cache.get(key, default)
-
-    def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
-        if self.active:
-            self.cache[key] = value
-
-    def clear(self):
-        self.cache.clear()
-
-    def __enter__(self):
-        self.active = True
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.active = False
-        self.clear()
-
-
-cache_helper = CacheHelper()
+def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None):
+    from ..cmd.main_pre import args
+    configuration = configuration or args
+    base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
+    base_paths = [path for path in base_paths if path is not None]
+    if len(base_paths) == 0:
+        base_paths = [Path(os.getcwd())]
+    for base_path in base_paths:
+        folder_names_and_paths.add_base_path(base_path)
+    folder_names_and_paths.add(ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}))
+    folder_names_and_paths.add(ModelPaths(["vae"], supported_extensions={".yaml"}))
+    folder_names_and_paths.add(ModelPaths(["clip"], supported_extensions={".yaml"}))
+    folder_names_and_paths.add(ModelPaths(["loras"], supported_extensions={".yaml"}))
+    folder_names_and_paths.add(ModelPaths(["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["diffusers"], supported_extensions=set()))
+    folder_names_and_paths.add(ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()))
+    folder_names_and_paths.add(ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)))
+    folder_names_and_paths.add(ModelPaths(["classifiers"], supported_extensions=set()))
+    folder_names_and_paths.add(ModelPaths(["huggingface"], supported_extensions=set()))
+    hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set())
+    # TODO: explore if there is a better way to do this
+    if "HF_HUB_CACHE" in os.environ:
+        hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
+    folder_names_and_paths.add(hf_cache_paths)
+    create_directories(folder_names_and_paths)
+
+
+@_module_properties.getter
+def _folder_names_and_paths():
+    return current_execution_context().folder_names_and_paths
+
+
+@_module_properties.getter
+def _models_dir():
+    return str(Path(current_execution_context().folder_names_and_paths.base_paths[0]) / construct_path("models"))
+
+
+@_module_properties.getter
+def _user_directory() -> str:
+    return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve())
+
+
+@_module_properties.getter
+def _temp_directory() -> str:
+    return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve())
+
+
+@_module_properties.getter
+def _input_directory() -> str:
+    return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve())
+
+
+@_module_properties.getter
+def _output_directory() -> str:
+    return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve())


+@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.")
 def map_legacy(folder_name: str) -> str:
     legacy = {"unet": "diffusion_models"}
     return legacy.get(folder_name, folder_name)


-if not os.path.exists(input_directory):
-    try:
-        os.makedirs(input_directory)
-    except:
-        logging.error("Failed to create input directory")
-
-
-def set_output_directory(output_dir):
-    global output_directory
-    output_directory = output_dir
+def set_output_directory(output_dir: str | Path):
+    _folder_names_and_paths().application_paths.output_directory = construct_path(output_dir)


-def set_temp_directory(temp_dir):
-    global temp_directory
-    temp_directory = temp_dir
+def set_temp_directory(temp_dir: str | Path):
+    _folder_names_and_paths().application_paths.temp_directory = construct_path(temp_dir)


-def set_input_directory(input_dir):
-    global input_directory
-    input_directory = input_dir
+def set_input_directory(input_dir: str | Path):
+    _folder_names_and_paths().application_paths.input_directory = construct_path(input_dir)


-def get_output_directory():
-    global output_directory
-    return output_directory
+def get_output_directory() -> str:
+    return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve())


-def get_temp_directory():
-    global temp_directory
-    return temp_directory
+def get_temp_directory() -> str:
+    return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve())


-def get_input_directory():
-    global input_directory
-    return input_directory
+def get_input_directory() -> str:
+    return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve())


 def get_user_directory() -> str:
-    return user_directory
+    return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve())


-def set_user_directory(user_dir: str) -> None:
-    global user_directory
-    user_directory = user_dir
+def set_user_directory(user_dir: str | Path) -> None:
+    _folder_names_and_paths().application_paths.user_directory = construct_path(user_dir)


 # NOTE: used in http server so don't put folders that should not be accessed remotely
-def get_directory_by_type(type_name):
+def get_directory_by_type(type_name) -> str | None:
     if type_name == "output":
         return get_output_directory()
     if type_name == "temp":
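
The create_module_properties pattern above turns module globals like models_dir and output_directory into computed attributes that track the current execution context. A standalone sketch of the underlying idiom (PEP 562 module-level __getattr__); the names are illustrative:

```python
# mini_folder_paths.py
import os

def __getattr__(name: str):
    # evaluated on each attribute access, so the value can follow runtime state
    if name == "models_dir":
        return os.path.join(os.getcwd(), "models")
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```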
@@ -151,7 +153,7 @@ def get_directory_by_type(type_name):

 # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
 # otherwise use default_path as base_dir
-def annotated_filepath(name):
+def annotated_filepath(name: str) -> tuple[str, str | None]:
     if name.endswith("[output]"):
         base_dir = get_output_directory()
         name = name[:-9]
@@ -167,7 +169,7 @@ def annotated_filepath(name):
     return name, base_dir


-def get_annotated_filepath(name, default_dir=None):
+def get_annotated_filepath(name, default_dir=None) -> str:
     name, base_dir = annotated_filepath(name)

     if base_dir is None:
@@ -198,9 +200,11 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
     :param extensions: supported file extensions
     :return: the folder path
     """
-    global folder_names_and_paths
+    folder_names_and_paths = _folder_names_and_paths()
     if full_folder_path is None:
-        full_folder_path = os.path.join(models_dir, folder_name)
+        # todo: this should use the subdir pattern
+        full_folder_path = os.path.join(_models_dir(), folder_name)

     folder_path = folder_names_and_paths[folder_name]
     if full_folder_path not in folder_path.paths:
@@ -212,48 +216,16 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
     if extensions is not None:
         folder_path.supported_extensions |= extensions
+    invalidate_cache(folder_name)
     return full_folder_path


 def get_folder_paths(folder_name) -> List[str]:
-    return folder_names_and_paths[folder_name].paths[:]
+    return [path for path in _folder_names_and_paths()[folder_name].paths]


-def recursive_search(directory, excluded_dir_names=None):
-    if not os.path.isdir(directory):
-        return [], {}
-
-    if excluded_dir_names is None:
-        excluded_dir_names = []
-
-    result = []
-    dirs = {}
-
-    # Attempt to add the initial directory to dirs with error handling
-    try:
-        dirs[directory] = os.path.getmtime(directory)
-    except FileNotFoundError:
-        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
-
-    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
-        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
-        for file_name in filenames:
-            try:
-                relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
-                result.append(relative_path)
-            except:
-                logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
-                continue
-        for d in subdirs:
-            path = os.path.join(dirpath, d)
-            try:
-                dirs[path] = os.path.getmtime(path)
-            except FileNotFoundError:
-                logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
-                continue
-    return result, dirs
+@_deprecate_method(version="0.2.3", message="Not supported")
+def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]:
+    raise NotImplemented("Unsupported method")


 def filter_files_extensions(files, extensions):
@ -264,92 +236,27 @@ def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
""" """
Gets the path to a filename inside a folder. Gets the path to a filename inside a folder.
Works with untrusted filenames.
:param folder_name: :param folder_name:
:param filename: :param filename:
:return: :return:
""" """
global folder_names_and_paths path = _folder_names_and_paths().first_existing_or_none(folder_name, construct_path(filename))
folders = folder_names_and_paths[folder_name].paths return str(path) if path is not None else None
filename_split = os.path.split(filename)
trusted_paths = []
for folder in folders:
folder_split = os.path.split(folder)
abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split))
abs_folder_path = os.path.abspath(folder)
if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path:
trusted_paths.append(abs_file_path)
else:
logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}")
for trusted_path in trusted_paths:
if os.path.isfile(trusted_path):
return trusted_path
return None
def get_full_path_or_raise(folder_name: str, filename: str) -> str: def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename) full_path = get_full_path(folder_name, filename)
if full_path is None: if full_path is None:
# todo: probably shouldn't say model
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.") raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
output_folders = {}
for x in folders[0]:
files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
for x in out[1]:
time_modified = out[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
if os.path.isdir(x):
if x not in out[1]:
return None
return out
def get_filename_list(folder_name: str) -> list[str]: def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name) return [str(path) for path in _folder_names_and_paths().file_paths(folder_name=folder_name, relative=True)]
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
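The new `get_filename_list` delegates the walk to `FolderNames.file_paths` on every call instead of consulting the removed cache. A minimal usage sketch, assuming the `comfy` package layout in this commit and a configured `checkpoints` folder that contains files:

```python
# Hypothetical sketch: names are gathered fresh from every configured
# base path on each call; paths are returned relative to their folder.
from comfy.cmd import folder_paths

for name in folder_paths.get_filename_list("checkpoints"):
    print(name)  # e.g. "sd15.safetensors" or "subdir/sdxl.safetensors"
```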
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0) -> SaveImagePathTuple:
def map_filename(filename: str) -> tuple[int, str]: def map_filename(filename: str) -> tuple[int, str]:
prefix_len = len(os.path.basename(filename_prefix)) prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1] prefix = filename[:prefix_len + 1]
@ -378,14 +285,6 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder = str(os.path.join(output_dir, subfolder)) full_output_folder = str(os.path.join(output_dir, subfolder))
if str(os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))) != str(output_dir):
err = f"""**** ERROR: Saving image outside the output folder is not allowed.
full_output_folder: {os.path.abspath(full_output_folder)}
output_dir: {output_dir}
commonpath: {os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))}"""
logging.error(err)
raise Exception(err)
try: try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError: except ValueError:
@ -393,21 +292,22 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
except FileNotFoundError: except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True) os.makedirs(full_output_folder, exist_ok=True)
counter = 1 counter = 1
return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix) return SaveImagePathTuple(full_output_folder, filename, counter, subfolder, filename_prefix)
def create_directories(): def create_directories(paths: FolderNames | None):
# all configured paths should be created # all configured paths should be created
for folder_path_spec in folder_names_and_paths.values(): paths = paths or _folder_names_and_paths()
for folder_path_spec in paths.values():
for path in folder_path_spec.paths: for path in folder_path_spec.paths:
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
for path in (temp_directory, input_directory, output_directory, user_directory): for path in paths.application_paths:
os.makedirs(path, exist_ok=True) path.mkdir(exist_ok=True)
@_deprecate_method(version="0.2.3", message="Caching has been removed.")
def invalidate_cache(folder_name): def invalidate_cache(folder_name):
global filename_list_cache pass
filename_list_cache.pop(folder_name, None)
def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]: def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
@ -416,7 +316,7 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
files = os.listdir(folder_paths.get_input_directory()) files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"]) filter_files_content_types(files, ["image", "audio", "video"])
""" """
global extension_mimetypes_cache extension_mimetypes_cache = _extension_mimetypes_cache()
result = [] result = []
for file in files: for file in files:
extension = file.split('.')[-1] extension = file.split('.')[-1]
@ -432,3 +332,52 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
if content_type in content_types: if content_type in content_types:
result.append(file) result.append(file)
return result return result
@_module_properties.getter
def _cache_helper():
return nullcontext()
# todo: can this be done side effect free?
init_default_paths(_folder_names_and_paths())
__all__ = [
# Properties (stripped leading underscore)
"supported_pt_extensions", # from _supported_pt_extensions
"extension_mimetypes_cache", # from _extension_mimetypes_cache
"base_path", # from _base_path
"folder_names_and_paths", # from _folder_names_and_paths
"models_dir", # from _models_dir
"user_directory",
"output_directory",
"temp_directory",
"input_directory",
# Public functions
"init_default_paths",
"map_legacy",
"set_output_directory",
"set_temp_directory",
"set_input_directory",
"get_output_directory",
"get_temp_directory",
"get_input_directory",
"get_user_directory",
"set_user_directory",
"get_directory_by_type",
"annotated_filepath",
"get_annotated_filepath",
"exists_annotated_filepath",
"add_model_folder_path",
"get_folder_paths",
"recursive_search",
"filter_files_extensions",
"get_full_path",
"get_full_path_or_raise",
"get_filename_list",
"get_save_image_path",
"create_directories",
"invalidate_cache",
"filter_files_content_types"
]
View File
@ -1,31 +0,0 @@
import logging
import os
from ..cli_args import args
_base_path = None
# todo: this should be initialized elsewhere in a context
def get_base_path() -> str:
global _base_path
if _base_path is None:
if args.cwd is not None:
if not os.path.exists(args.cwd):
try:
os.makedirs(args.cwd, exist_ok=True)
except:
logging.error("Failed to create custom working directory")
# wrap the path to prevent slashedness from glitching out common path checks
_base_path = os.path.realpath(args.cwd)
else:
_base_path = os.getcwd()
return _base_path
def set_base_path(value: str):
global _base_path
_base_path = value
__all__ = ["get_base_path", "set_base_path"]
View File
@ -1,4 +1,5 @@
import asyncio import asyncio
import contextvars
import gc import gc
import itertools import itertools
import logging import logging
@ -193,7 +194,9 @@ async def main(from_script_dir: Optional[Path] = None):
if not distributed or args.distributed_queue_worker: if not distributed or args.distributed_queue_worker:
if distributed: if distributed:
logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.") logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start() # todo: this should really be using an executor instead of doing things this jankilicious way
ctx = contextvars.copy_context()
threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start()
# server has been imported and things should be looking good # server has been imported and things should be looking good
initialize_event_tracking(loop) initialize_event_tracking(loop)
View File
@ -111,13 +111,7 @@ def _create_tracer():
def _configure_logging(): def _configure_logging():
if isinstance(args.verbose, str): logging_level = args.logging_level
logging_level = args.verbose
elif args.verbose == True:
logging_level = logging.DEBUG
else:
logging_level = logging.ERROR
logging.basicConfig(format="%(message)s", level=logging_level) logging.basicConfig(format="%(message)s", level=logging_level)
View File
@ -28,8 +28,8 @@ from aiohttp import web
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple from typing_extensions import NamedTuple
from .. import __version__
from .latent_preview_image_encoding import encode_preview_image from .latent_preview_image_encoding import encode_preview_image
from .. import __version__
from .. import interruption, model_management from .. import interruption, model_management
from .. import node_helpers from .. import node_helpers
from .. import utils from .. import utils
@ -241,7 +241,7 @@ class PromptServer(ExecutorToClientProgress):
return response return response
@routes.get("/embeddings") @routes.get("/embeddings")
def get_embeddings(self): def get_embeddings(request):
embeddings = folder_paths.get_filename_list("embeddings") embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@ -568,15 +568,14 @@ class PromptServer(ExecutorToClientProgress):
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
with folder_paths.cache_helper: out = {}
out = {} for x in self.nodes.NODE_CLASS_MAPPINGS:
for x in self.nodes.NODE_CLASS_MAPPINGS: try:
try: out[x] = node_info(x)
out[x] = node_info(x) except Exception as e:
except Exception as e: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc())
logging.error(traceback.format_exc()) return web.json_response(out)
return web.json_response(out)
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
async def get_object_info_node(request): async def get_object_info_node(request):
@ -721,21 +720,21 @@ class PromptServer(ExecutorToClientProgress):
async def get_api_v1_prompts_prompt_id(request: web.Request) -> web.Response | web.FileResponse: async def get_api_v1_prompts_prompt_id(request: web.Request) -> web.Response | web.FileResponse:
prompt_id: str = request.match_info.get("prompt_id", "") prompt_id: str = request.match_info.get("prompt_id", "")
if prompt_id == "": if prompt_id == "":
return web.Response(status=404) return web.json_response(status=404)
history_items = self.prompt_queue.get_history(prompt_id) history_items = self.prompt_queue.get_history(prompt_id)
if len(history_items) == 0 or prompt_id not in history_items: if len(history_items) == 0 or prompt_id not in history_items:
# todo: this should really be moved to a stateful queue abstraction # todo: this should really be moved to a stateful queue abstraction
if prompt_id in self.background_tasks: if prompt_id in self.background_tasks:
return web.Response(status=204) return web.json_response(status=204)
else: else:
# todo: this should check a stateful queue abstraction # todo: this should check a stateful queue abstraction
return web.Response(status=404) return web.json_response(status=404)
elif prompt_id in history_items: elif prompt_id in history_items:
history_entry = history_items[prompt_id] history_entry = history_items[prompt_id]
return web.json_response(history_entry["outputs"]) return web.json_response(history_entry["outputs"])
else: else:
return web.Response(status=500) return web.json_response(status=500)
@routes.post("/api/v1/prompts") @routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse: async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
@ -751,8 +750,8 @@ class PromptServer(ExecutorToClientProgress):
queue_size = self.prompt_queue.size() queue_size = self.prompt_queue.size()
queue_too_busy_size = PromptServer.get_too_busy_queue_size() queue_too_busy_size = PromptServer.get_too_busy_queue_size()
if queue_size > queue_too_busy_size: if queue_size > queue_too_busy_size:
return web.Response(status=429, return web.json_response(status=429,
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker") reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
# read the request # read the request
prompt_dict: dict = {} prompt_dict: dict = {}
if content_type == 'application/json': if content_type == 'application/json':
View File
@ -2,10 +2,11 @@ import asyncio
import itertools import itertools
import logging import logging
import os import os
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from .extra_model_paths import load_extra_path_config
from .main_pre import args from .main_pre import args
from .extra_model_paths import load_extra_path_config
from ..distributed.executors import ContextVarExecutor
async def main(): async def main():
@ -43,7 +44,7 @@ async def main():
from ..distributed.distributed_prompt_worker import DistributedPromptWorker from ..distributed.distributed_prompt_worker import DistributedPromptWorker
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri, async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name, queue_name=args.distributed_queue_name,
executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)): executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
stop = asyncio.Event() stop = asyncio.Event()
try: try:
await stop.wait() await stop.wait()
View File
@ -8,6 +8,7 @@ from typing import Optional, Literal, Protocol, Union, NamedTuple, List
import PIL.Image import PIL.Image
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
from .queue_types import BinaryEventTypes from .queue_types import BinaryEventTypes
from ..cli_args_types import Configuration from ..cli_args_types import Configuration
from ..nodes.package_typing import InputTypeSpec from ..nodes.package_typing import InputTypeSpec
@ -205,8 +206,8 @@ class DuplicateNodeError(Exception):
class HistoryResultDict(TypedDict, total=True): class HistoryResultDict(TypedDict, total=True):
outputs: dict outputs: OutputsDict
meta: dict meta: OutputsDict
class DependencyCycleError(Exception): class DependencyCycleError(Exception):
View File
@ -1,15 +1,9 @@
import os
from typing import Literal, Optional
from pathlib import Path from pathlib import Path
from typing import Literal, Optional
from ..cmd import folder_paths from ..cmd import folder_paths
def _is_strictly_below_root(path: Path) -> bool:
resolved_path = path.resolve()
return ".." not in resolved_path.parts and resolved_path.is_absolute()
def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "output", def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "output",
subfolder: Optional[str] = None) -> str: subfolder: Optional[str] = None) -> str:
""" """
@ -22,22 +16,15 @@ def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "
:return: :return:
""" """
filename, output_dir = folder_paths.annotated_filepath(str(filename)) filename, output_dir = folder_paths.annotated_filepath(str(filename))
if not _is_strictly_below_root(Path(filename)):
raise PermissionError("insecure")
if output_dir is None: if output_dir is None:
output_dir = folder_paths.get_directory_by_type(type) output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None: if output_dir is None:
raise ValueError(f"no such output directory because invalid type specified (type={type})") raise ValueError(f"no such output directory because invalid type specified (type={type})")
if subfolder is not None and subfolder != "": output_dir = Path(output_dir)
full_output_dir = str(os.path.join(output_dir, subfolder)) subfolder = Path(subfolder or "")
if str(os.path.commonpath([os.path.abspath(full_output_dir), output_dir])) != str(output_dir): try:
raise PermissionError("insecure") relative_to = (output_dir / subfolder / filename).relative_to(output_dir)
output_dir = full_output_dir except ValueError:
filename = os.path.basename(filename) raise PermissionError(f"{output_dir / subfolder / filename} is not a subpath of {output_dir}")
else: return str((output_dir / relative_to).resolve(strict=True))
if str(os.path.commonpath([os.path.abspath(output_dir), os.path.join(output_dir, filename)])) != str(output_dir):
raise PermissionError("insecure")
file = os.path.join(output_dir, filename)
return file
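The rewrite replaces the two `os.path.commonpath` checks with a single `Path.relative_to` containment test. A minimal sketch of the idea, assuming only the standard library; resolving before comparing (a variation on the code above) also guards against lexical `..` segments:

```python
from pathlib import Path

def is_contained(output_dir: str, subfolder: str, filename: str) -> bool:
    # relative_to raises ValueError when the joined path is not
    # below base, replacing the old commonpath comparison
    base = Path(output_dir).resolve()
    candidate = (base / (subfolder or "") / filename).resolve()
    try:
        candidate.relative_to(base)
        return True
    except ValueError:
        return False

assert is_contained("/srv/out", "images", "a.png")
assert not is_contained("/srv/out", "", "../escape.png")
```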
View File
@ -1,9 +1,15 @@
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import itertools
import os import os
import typing import typing
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple import weakref
from abc import ABC, abstractmethod
from typing import Any, NamedTuple, Optional, Iterable
from pathlib import Path
from .platform_path import construct_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"]) supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"])
extension_mimetypes_cache = { extension_mimetypes_cache = {
@ -11,35 +17,267 @@ extension_mimetypes_cache = {
} }
@dataclasses.dataclass def do_add(collection: list | set, index: int | None, item: Any):
if isinstance(collection, list) and index == 0:
collection.insert(0, item)
elif isinstance(collection, set):
collection.add(item)
else:
assert isinstance(collection, list)
collection.append(item)
class FolderPathsTuple: class FolderPathsTuple:
folder_name: str def __init__(self, folder_name: str = None, paths: list[str] = None, supported_extensions: set[str] = None, parent: Optional[weakref.ReferenceType[FolderNames]] = None):
paths: List[str] = dataclasses.field(default_factory=list) paths = paths or []
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions)) supported_extensions = supported_extensions or set(supported_pt_extensions)
self.folder_name = folder_name
self.parent = parent
self._paths = paths
self._supported_extensions = supported_extensions
@property
def supported_extensions(self) -> set[str] | SupportedExtensions:
if self.parent is not None and self.folder_name is not None:
return SupportedExtensions(self.folder_name, self.parent)
else:
return self._supported_extensions
@supported_extensions.setter
def supported_extensions(self, value: set[str] | SupportedExtensions):
self.supported_extensions.clear()
self.supported_extensions.update(value)
@property
def paths(self) -> list[str] | PathsList:
if self.parent is not None and self.folder_name is not None:
return PathsList(self.folder_name, self.parent)
else:
return self._paths
def __iter__(self) -> typing.Generator[typing.Iterable[str]]:
"""
allows this proxy to behave like a tuple everywhere it is used by the custom nodes
:return:
"""
yield self.paths
yield self.supported_extensions
def __getitem__(self, item: Any): def __getitem__(self, item: Any):
if item == 0: if item == 0:
return self.paths return self.paths
if item == 1: if item == 1:
return self.supported_extensions return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"): raise RuntimeError("unsupported tuple index")
assert self.folder_name == other.folder_name
# todo: make sure the paths are actually unique, as this method intends
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
def __iter__(self) -> Iterator[Sequence[str]]: def __iadd__(self, other: FolderPathsTuple):
yield self.paths for path in other.paths:
yield self.supported_extensions self.paths.append(path)
for ext in other.supported_extensions:
self.supported_extensions.add(ext)
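`FolderPathsTuple` is now a live proxy into its parent `FolderNames`, but it still unpacks and indexes like the legacy `(paths, extensions)` tuple. A hedged sketch of the compatibility surface, assuming a populated `FolderNames` instance `fn`:

```python
# Hypothetical sketch: legacy custom-node code that treated the entry
# as a plain tuple keeps working against the proxy.
entry = fn["checkpoints"]      # FolderPathsTuple backed by a weakref to fn
paths, extensions = entry      # __iter__ yields paths, then extensions
first_dir = entry[0][0]        # __getitem__(0) -> PathsList proxy
entry[1].add(".gguf")          # mutations write through to fn
```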
@dataclasses.dataclass
class PathsList:
folder_name: str
parent: weakref.ReferenceType[FolderNames]
def __iter__(self) -> typing.Generator[str]:
p: FolderNames = self.parent()
for path in p.directory_paths(self.folder_name):
try:
yield str(path.resolve())
except (OSError, AttributeError):
yield str(path)
def __getitem__(self, item: int):
paths = [x for x in self]
return paths[item]
def append(self, path_str: str):
p: FolderNames = self.parent()
p.add_paths(self.folder_name, [path_str])
def insert(self, path_str: str, index: int):
p: FolderNames = self.parent()
p.add_paths(self.folder_name, [path_str], index=index)
@dataclasses.dataclass
class SupportedExtensions:
folder_name: str
parent: weakref.ReferenceType[FolderNames]
def _append_any(self, other):
if other is None:
return
p: FolderNames = self.parent()
if isinstance(other, str):
other = {other}
p.add_supported_extension(self.folder_name, *other)
def __iter__(self) -> typing.Generator[str]:
p: FolderNames = self.parent()
for ext in p.supported_extensions(self.folder_name):
yield ext
def clear(self):
p: FolderNames = self.parent()
p.remove_all_supported_extensions(self.folder_name)
__ior__ = _append_any
add = _append_any
update = _append_any
@dataclasses.dataclass
class ApplicationPaths:
output_directory: Path = dataclasses.field(default_factory=lambda: construct_path("output"))
temp_directory: Path = dataclasses.field(default_factory=lambda: construct_path("temp"))
input_directory: Path = dataclasses.field(default_factory=lambda: construct_path("input"))
user_directory: Path = dataclasses.field(default_factory=lambda: construct_path("user"))
def __iter__(self) -> typing.Generator[Path]:
yield self.output_directory
yield self.temp_directory
yield self.input_directory
yield self.user_directory
@dataclasses.dataclass
class AbstractPaths(ABC):
folder_names: list[str]
supported_extensions: set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
@abstractmethod
def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]:
"""Generate directory paths based on the given base paths."""
pass
@abstractmethod
def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]:
"""Generate file paths based on the given base paths."""
pass
@abstractmethod
def has_folder_name(self, folder_name: str) -> bool:
"""Check if the given folder name is in folder_names."""
pass
@dataclasses.dataclass
class ModelPaths(AbstractPaths):
folder_name_base_path_subdir: Path = dataclasses.field(default_factory=lambda: construct_path("models"))
additional_relative_directory_paths: set[Path] = dataclasses.field(default_factory=set)
additional_absolute_directory_paths: set[str | Path] = dataclasses.field(default_factory=set)
folder_names_are_relative_directory_paths_too: bool = dataclasses.field(default_factory=lambda: True)
def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]:
yielded_so_far: set[Path] = set()
for base_path in base_paths:
if self.folder_names_are_relative_directory_paths_too:
for folder_name in self.folder_names:
resolved_default_folder_name_path = base_path / self.folder_name_base_path_subdir / folder_name
if not resolved_default_folder_name_path in yielded_so_far:
yield resolved_default_folder_name_path
yielded_so_far.add(resolved_default_folder_name_path)
for additional_relative_directory_path in self.additional_relative_directory_paths:
resolved_additional_relative_path = base_path / additional_relative_directory_path
if not resolved_additional_relative_path in yielded_so_far:
yield resolved_additional_relative_path
yielded_so_far.add(resolved_additional_relative_path)
# resolve all paths
yielded_so_far = {path.resolve() for path in yielded_so_far}
for additional_absolute_path in self.additional_absolute_directory_paths:
try:
resolved_absolute_path = additional_absolute_path.resolve()
except (OSError, AttributeError):
resolved_absolute_path = additional_absolute_path
if not resolved_absolute_path in yielded_so_far:
yield resolved_absolute_path
yielded_so_far.add(resolved_absolute_path)
def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]:
for path in self.directory_paths(base_paths):
for dirpath, dirnames, filenames in os.walk(path, followlinks=True):
if '.git' in dirnames:
dirnames.remove('.git')
for filename in filenames:
if any(filename.endswith(ext) for ext in self.supported_extensions):
result_path = construct_path(dirpath) / filename
if relative:
yield result_path.relative_to(path)
else:
yield result_path
def has_folder_name(self, folder_name: str) -> bool:
return folder_name in self.folder_names
@dataclasses.dataclass
class FolderNames: class FolderNames:
application_paths: typing.Optional[ApplicationPaths] = dataclasses.field(default_factory=ApplicationPaths)
contents: list[AbstractPaths] = dataclasses.field(default_factory=list)
base_paths: list[Path] = dataclasses.field(default_factory=list)
def supported_extensions(self, folder_name: str) -> typing.Generator[str]:
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
for supported_extension in candidate.supported_extensions:
yield supported_extension
def directory_paths(self, folder_name: str) -> typing.Generator[Path]:
for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths)
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
yield directory_path
def file_paths(self, folder_name: str, relative=False) -> typing.Generator[Path]:
for file_path in itertools.chain.from_iterable([candidate.file_paths(self.base_paths, relative=relative)
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
yield file_path
def first_existing_or_none(self, folder_name: str, relative_file_path: Path) -> Optional[Path]:
for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths)
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
candidate_file_path: Path = construct_path(directory_path / relative_file_path)
try:
# todo: this should follow the symlink
if Path.exists(candidate_file_path):
return candidate_file_path
except OSError:
continue
return None
def add_supported_extension(self, folder_name: str, *supported_extensions: str | None):
if supported_extensions is None:
return
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
candidate.supported_extensions.update(supported_extensions)
def remove_all_supported_extensions(self, folder_name: str):
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
candidate.supported_extensions.clear()
def add(self, model_paths: AbstractPaths):
self.contents.append(model_paths)
def add_base_path(self, base_path: Path):
if base_path not in self.base_paths:
self.base_paths.append(base_path)
@staticmethod @staticmethod
def from_dict(folder_paths_dict: dict[str, tuple[typing.Sequence[str], Sequence[str]]] = None) -> FolderNames: def from_dict(folder_paths_dict: dict[str, tuple[typing.Iterable[str], Iterable[str]]] = None) -> FolderNames:
""" """
Turns a dictionary of Turns a dictionary of
{ {
@ -50,65 +288,179 @@ class FolderNames:
:param folder_paths_dict: A dictionary :param folder_paths_dict: A dictionary
:return: A FolderNames object :return: A FolderNames object
""" """
if folder_paths_dict is None: raise NotImplementedError()
return FolderNames(os.getcwd())
fn = FolderNames(os.getcwd()) def __getitem__(self, folder_name) -> FolderPathsTuple:
for folder_name, (paths, extensions) in folder_paths_dict.items(): if not isinstance(folder_name, str) or folder_name is None:
paths_tuple = FolderPathsTuple(folder_name=folder_name, paths=list(paths), supported_extensions=set(extensions))
if folder_name in fn:
fn[folder_name] += paths_tuple
else:
fn[folder_name] = paths_tuple
return fn
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path") raise RuntimeError("expected folder path")
if item not in self.contents: # todo: it is probably never the intention to do this
default_path = os.path.join(self.default_new_folder_path, item) try:
os.makedirs(default_path, exist_ok=True) path = Path(folder_name)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set()) if path.is_absolute():
return self.contents[item] folder_name = path.stem
except Exception:
pass
if not any(candidate.has_folder_name(folder_name) for candidate in self.contents):
self.add(ModelPaths(folder_names=[folder_name], folder_name_base_path_subdir=construct_path(), supported_extensions=set(), folder_names_are_relative_directory_paths_too=False))
return FolderPathsTuple(folder_name, parent=weakref.ref(self))
def __setitem__(self, key: str, value: FolderPathsTuple): def add_paths(self, folder_name: str, paths: list[Path | str], index: Optional[int] = None):
assert isinstance(key, str) """
if isinstance(value, tuple): Adds, but does not create, new model paths
:param folder_name:
:param paths:
:param index:
:return:
"""
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
self._modify_model_paths(folder_name, paths, set(), candidate, index=index)
def _modify_model_paths(self, key: str, paths: Iterable[Path | str], supported_extensions: set[str], model_paths: AbstractPaths = None, index: Optional[int] = None) -> AbstractPaths:
model_paths = model_paths or ModelPaths([key],
supported_extensions=set(supported_extensions),
folder_names_are_relative_directory_paths_too=False)
if index is not None and index != 0:
raise ValueError(f"index was {index} but only 0 or None is supported")
for path in paths:
if isinstance(path, str):
path = construct_path(path)
# when given absolute paths, try to formulate them as relative paths anyway
if path.is_absolute():
for base_path in self.base_paths:
try:
relative_to_basepath = path.relative_to(base_path)
potential_folder_name = relative_to_basepath.stem
potential_subdir = relative_to_basepath.parent
# does the folder_name so far match the key?
# or have we not seen this folder before?
folder_name_not_already_in_contents = all(not candidate.has_folder_name(potential_folder_name) for candidate in self.contents)
if potential_folder_name == key or folder_name_not_already_in_contents:
# fix the subdir
model_paths.folder_name_base_path_subdir = potential_subdir
model_paths.folder_names_are_relative_directory_paths_too = True
if folder_name_not_already_in_contents:
do_add(model_paths.folder_names, index, potential_folder_name)
else:
# if the folder name doesn't match the key, check if we have ever seen a folder name that matches the key:
if model_paths.folder_names_are_relative_directory_paths_too:
if potential_subdir == model_paths.folder_name_base_path_subdir:
# then we want to add this folder name too, because it's probably a compatibility alias
do_add(model_paths.folder_names, index, potential_folder_name)
else:
# not this case
model_paths.folder_names_are_relative_directory_paths_too = False
if not model_paths.folder_names_are_relative_directory_paths_too:
model_paths.additional_relative_directory_paths.add(relative_to_basepath)
for resolve_folder_name in model_paths.folder_names:
model_paths.additional_relative_directory_paths.add(model_paths.folder_name_base_path_subdir / resolve_folder_name)
# since this was an absolute path that was a subdirectory of one of the base paths,
# we are done
break
except ValueError:
# this is not a subpath of the base path
pass
# if we got this far, none of the absolute paths were subdirectories of any base paths
# add it to our absolute paths
model_paths.additional_absolute_directory_paths.add(path)
else:
# since this is a relative path, peacefully add it to model_paths
potential_folder_name = path.stem
try:
relative_to_current_subdir = path.relative_to(model_paths.folder_name_base_path_subdir)
# if, relative to the current subdir, we observe only one part, we're good to go
if len(relative_to_current_subdir.parts) == 1:
if potential_folder_name == key:
model_paths.folder_names_are_relative_directory_paths_too = True
else:
# if there already exists a folder_name by this name, do not add it, and switch to all relative paths
if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents):
model_paths.folder_names_are_relative_directory_paths_too = False
model_paths.additional_relative_directory_paths.add(path)
model_paths.folder_name_base_path_subdir = construct_path()
else:
do_add(model_paths.folder_names, index, potential_folder_name)
except ValueError:
# this means that the relative directory didn't contain the subdir so far
# something_not_models/key
if potential_folder_name == key:
model_paths.folder_name_base_path_subdir = path.parent
model_paths.folder_names_are_relative_directory_paths_too = True
else:
if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents):
model_paths.folder_names_are_relative_directory_paths_too = False
model_paths.additional_relative_directory_paths.add(path)
model_paths.folder_name_base_path_subdir = construct_path()
else:
do_add(model_paths.folder_names, index, potential_folder_name)
return model_paths
def __setitem__(self, key: str, value: tuple | FolderPathsTuple | AbstractPaths):
# remove all existing paths for this key
self.__delitem__(key)
if isinstance(value, AbstractPaths):
self.contents.append(value)
elif isinstance(value, (tuple, FolderPathsTuple)):
paths, supported_extensions = value paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions) # typical cases:
self.contents[key] = value # key="checkpoints", paths="C:/base_path/models/checkpoints"
# key="unets", paths="C:/base_path/models/unets", "C:/base_path/models/diffusion_models"
# ^ in this case, we will want folder_names to be both unets and diffusion_models
# key="custom_loader", paths="C:/base_path/models/checkpoints"
# if the paths are subdirectories of any basepath, use relative paths
paths: list[Path] = list(map(Path, paths))
self.contents.append(self._modify_model_paths(key, paths, supported_extensions))
def __len__(self): def __len__(self):
return len(self.contents) return len(self.contents)
def __iter__(self): def __iter__(self) -> typing.Generator[str]:
return iter(self.contents) for model_paths in self.contents:
for folder_name in model_paths.folder_names:
yield folder_name
def __delitem__(self, key): def __delitem__(self, key):
del self.contents[key] to_remove: list[AbstractPaths] = []
for model_paths in self.contents:
if model_paths.has_folder_name(key):
to_remove.append(model_paths)
def __contains__(self, item): for model_paths in to_remove:
return item in self.contents self.contents.remove(model_paths)
def __contains__(self, item: str):
return any(candidate.has_folder_name(item) for candidate in self.contents)
def items(self): def items(self):
return self.contents.items() items_view = {
folder_name: self[folder_name] for folder_name in self.keys()
}
return items_view.items()
def values(self): def values(self):
return self.contents.values() return [self[folder_name] for folder_name in self.keys()]
def keys(self): def keys(self):
return self.contents.keys() return [x for x in self]
def get(self, key, __default=None): def get(self, key, __default=None):
return self.contents.get(key, __default) for candidate in self.contents:
if candidate.has_folder_name(key):
return FolderPathsTuple(key, parent=weakref.ref(self))
if __default is not None:
raise ValueError("get with default is not supported")
class SaveImagePathResponse(NamedTuple): class SaveImagePathTuple(NamedTuple):
full_output_folder: str full_output_folder: str
filename: str filename: str
counter: int counter: int
View File
@ -1,22 +1,44 @@
import sys import sys
from functools import wraps
def create_module_properties():
properties = {}
patched_modules = set()
def module_property(func): def patch_module(module):
"""Decorator to turn module functions into properties. if module in patched_modules:
Function names must be prefixed with an underscore.""" return
module = sys.modules[func.__module__]
def base_getattr(name): def base_getattr(name):
raise AttributeError( raise AttributeError(f"module '{module.__name__}' has no attribute '{name}'")
f"module '{module.__name__}' has no attribute '{name}'")
old_getattr = getattr(module, '__getattr__', base_getattr) old_getattr = getattr(module, '__getattr__', base_getattr)
def new_getattr(name): def new_getattr(name):
if f'_{name}' == func.__name__: if name in properties:
return func() return properties[name]()
else: else:
return old_getattr(name) return old_getattr(name)
module.__getattr__ = new_getattr module.__getattr__ = new_getattr
return func patched_modules.add(module)
class ModuleProperties:
@staticmethod
def getter(func):
@wraps(func)
def wrapper():
return func()
name = func.__name__
if name.startswith('_'):
properties[name[1:]] = wrapper
else:
raise ValueError("Property function names must start with an underscore")
module = sys.modules[func.__module__]
patch_module(module)
return wrapper
return ModuleProperties()
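`create_module_properties` turns module-level functions into computed attributes via a PEP 562 `__getattr__` patch, which is how `folder_paths` exposes values like `models_dir` as live lookups. A minimal usage sketch with illustrative names:

```python
# some_module.py -- hypothetical usage of create_module_properties
_module_properties = create_module_properties()

@_module_properties.getter
def _models_dir():
    # recomputed on every `some_module.models_dir` access
    return "/data/models"

# elsewhere:
#   import some_module
#   some_module.models_dir  -> calls _models_dir() through __getattr__
```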
View File
@ -0,0 +1,10 @@
from __future__ import annotations
import typing
# for nodes that return:
# { "ui" : { "some_field": Any, "other": Any }}
# the outputs dict will be
# (the node id)
# { "1": { "some_field": Any, "other": Any }}
OutputsDict = dict[str, dict[str, typing.Any]]
View File
@ -0,0 +1,13 @@
from __future__ import annotations
from pathlib import PurePosixPath, Path, PosixPath
def construct_path(*args) -> PurePosixPath | Path:
if len(args) > 0 and args[0] is not None and isinstance(args[0], str) and args[0].startswith("/"):
try:
return PosixPath(*args)
except Exception:
return PurePosixPath(*args)
else:
return Path(*args)
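As I read it, `construct_path` exists so that POSIX-absolute strings stay representable on Windows, where `PosixPath` cannot be instantiated and the pure variant is used instead:

```python
construct_path("models", "checkpoints")  # native Path("models/checkpoints")
construct_path("/srv/models")            # PosixPath on POSIX;
                                         # PurePosixPath on Windows, where
                                         # PosixPath() raises NotImplementedError
```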
View File
@ -8,13 +8,15 @@ from typing import Tuple
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
QueueTuple = Tuple[float, str, dict, dict, list] QueueTuple = Tuple[float, str, dict, dict, list]
MAXIMUM_HISTORY_SIZE = 10000 MAXIMUM_HISTORY_SIZE = 10000
class TaskInvocation(NamedTuple): class TaskInvocation(NamedTuple):
item_id: int | str item_id: int | str
outputs: dict outputs: OutputsDict
status: Optional[ExecutionStatus] status: Optional[ExecutionStatus]
View File
@ -10,11 +10,12 @@ from aio_pika.patterns import JsonRPC
from aiohttp import web from aiohttp import web
from aiormq import AMQPConnectionError from aiormq import AMQPConnectionError
from .executors import ContextVarExecutor
from .distributed_progress import DistributedExecutorToClientProgress from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply from .distributed_types import RpcRequest, RpcReply
from .process_pool_executor import ProcessPoolExecutor
from ..client.embedded_comfy_client import EmbeddedComfyClient from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..cmd.main_pre import tracer from ..cmd.main_pre import tracer
from ..component_model.executor_types import Executor
from ..component_model.queue_types import ExecutionStatus from ..component_model.queue_types import ExecutionStatus
@ -28,7 +29,7 @@ class DistributedPromptWorker:
queue_name: str = "comfyui", queue_name: str = "comfyui",
health_check_port: int = 9090, health_check_port: int = 9090,
loop: Optional[AbstractEventLoop] = None, loop: Optional[AbstractEventLoop] = None,
executor: Optional[Executor] = None): executor: Optional[ContextVarExecutor | ProcessPoolExecutor] = None):
self._rpc = None self._rpc = None
self._channel = None self._channel = None
self._exit_stack = AsyncExitStack() self._exit_stack = AsyncExitStack()
View File
@ -0,0 +1,24 @@
import concurrent
import contextvars
import typing
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
__version__ = '0.0.1'
from .process_pool_executor import ProcessPoolExecutor
class ContextVarExecutor(ThreadPoolExecutor):
def submit(self, fn: typing.Callable, *args, **kwargs) -> Future:
ctx = contextvars.copy_context() # type: contextvars.Context
return super().submit(partial(ctx.run, partial(fn, *args, **kwargs)))
class ContextVarProcessPoolExecutor(ProcessPoolExecutor):
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
# TODO: serialize the "comfyui_execution_context"
pass
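`ContextVarExecutor` copies the caller's `contextvars.Context` at `submit()` time, so state such as the execution context survives the hop onto a pool thread. A minimal sketch of the observable difference versus a plain `ThreadPoolExecutor`:

```python
# Hypothetical usage: the submitted callable sees the ContextVar values
# that were current in the submitting thread, not the pool thread's defaults.
import contextvars
from comfy.distributed.executors import ContextVarExecutor

request_id = contextvars.ContextVar("request_id", default="unset")
request_id.set("abc123")

with ContextVarExecutor(max_workers=1) as pool:
    assert pool.submit(request_id.get).result() == "abc123"
```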
View File
@ -6,6 +6,7 @@ from dataclasses import dataclass, replace
from typing import Optional, Final from typing import Optional, Final
from .component_model.executor_types import ExecutorToClientProgress from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames
from .distributed.server_stub import ServerStub from .distributed.server_stub import ServerStub
_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context") _current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
@ -14,23 +15,21 @@ _current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
@dataclass(frozen=True) @dataclass(frozen=True)
class ExecutionContext: class ExecutionContext:
server: ExecutorToClientProgress server: ExecutorToClientProgress
folder_names_and_paths: FolderNames
node_id: Optional[str] = None node_id: Optional[str] = None
task_id: Optional[str] = None task_id: Optional[str] = None
inference_mode: bool = True inference_mode: bool = True
_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub()) _current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames()))
def current_execution_context() -> ExecutionContext: def current_execution_context() -> ExecutionContext:
try: return _current_context.get()
return _current_context.get()
except LookupError:
return _empty_execution_context
@contextmanager @contextmanager
def new_execution_context(ctx: ExecutionContext): def _new_execution_context(ctx: ExecutionContext):
token = _current_context.set(ctx) token = _current_context.set(ctx)
try: try:
yield ctx yield ctx
@ -39,8 +38,24 @@ def new_execution_context(ctx: ExecutionContext):
@contextmanager @contextmanager
def context_execute_node(node_id: str, prompt_id: str): def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
current_ctx = current_execution_context() current_ctx = current_execution_context()
new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id) new_ctx = replace(current_ctx, folder_names_and_paths=folder_names_and_paths)
with new_execution_context(new_ctx): with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_execute_node(node_id: str):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, node_id=node_id)
with _new_execution_context(new_ctx):
yield new_ctx yield new_ctx
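With the split context managers, prompt-level and node-level state now layer independently. A hedged sketch of how an executor might nest them (the `server` object is a stand-in):

```python
# Hypothetical nesting sketch, assuming the context managers above.
from comfy.execution_context import (context_execute_prompt,
                                     context_execute_node,
                                     current_execution_context)

def run(server, prompt_id: str, node_ids: list[str]):
    with context_execute_prompt(server, prompt_id):
        for node_id in node_ids:
            with context_execute_node(node_id):
                ctx = current_execution_context()
                assert ctx.task_id == prompt_id
                assert ctx.node_id == node_id
```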
View File
@ -198,10 +198,6 @@ Visit the repository, accept the terms, and then do one of the following:
- Login to Hugging Face in your terminal using `huggingface-cli login` - Login to Hugging Face in your terminal using `huggingface-cli login`
""") """)
raise exc_info raise exc_info
finally:
# a path was found for any reason, so we should invalidate the cache
if path is not None:
folder_paths.invalidate_cache(folder_name)
if path is None: if path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.") raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
return path return path
@ -504,7 +500,6 @@ def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]
pre_existing = frozenset(known_models) pre_existing = frozenset(known_models)
known_models.extend([model for model in models if model not in pre_existing]) known_models.extend([model for model in models if model not in pre_existing])
folder_paths.invalidate_cache(folder_name)
return known_models return known_models
View File
@ -2,7 +2,6 @@ import torch
import os import os
import json import json
import hashlib
import math import math
import random import random
import logging import logging
@ -1654,7 +1653,7 @@ class LoadImageMask:
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required": return {"required":
{"image": (sorted(files), {"image_upload": True}), {"image": (natsorted(files), {"image_upload": True}),
"channel": (s._color_channels, ), } "channel": (s._color_channels, ), }
} }
View File
@ -29,9 +29,9 @@ import sys
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from pickle import UnpicklingError
from typing import Optional, Any from typing import Optional, Any
import accelerate
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
import torch import torch
@ -65,7 +65,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
if ckpt is None: if ckpt is None:
raise FileNotFoundError("the checkpoint was not found") raise FileNotFoundError("the checkpoint was not found")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type) sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
elif ckpt.lower().endswith("index.json"): elif ckpt.lower().endswith("index.json"):
# from accelerate # from accelerate
index_filename = ckpt index_filename = ckpt
@ -81,20 +81,29 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
for checkpoint_file in checkpoint_files: for checkpoint_file in checkpoint_files:
sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type)) sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
else: else:
if safe_load: try:
if not 'weights_only' in torch.load.__code__.co_varnames: if safe_load:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") if not 'weights_only' in torch.load.__code__.co_varnames:
safe_load = False logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
if safe_load: safe_load = False
pl_sd = torch.load(ckpt, map_location=device, weights_only=True) if safe_load:
else: pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) else:
if "global_step" in pl_sd: pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
logging.debug(f"Global Step: {pl_sd['global_step']}") if "global_step" in pl_sd:
if "state_dict" in pl_sd: logging.debug(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd:
else: sd = pl_sd["state_dict"]
sd = pl_sd else:
sd = pl_sd
except UnpicklingError as exc_info:
try:
# wrong extension is most likely, try to load as safetensors anyway
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
return sd
except Exception:
exc_info.add_note(f"The checkpoint at {ckpt} could not be loaded as a safetensors file or a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again.")
raise exc_info
return sd return sd
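The `UnpicklingError` fallback makes the loader tolerant of misnamed files: a checkpoint that is really a safetensors archive now loads instead of raising. A short usage sketch, assuming `load_torch_file` lives in `comfy.utils` as in upstream ComfyUI:

```python
# Hypothetical usage: the extension no longer has to match the real format.
from comfy.utils import load_torch_file

state_dict = load_torch_file("models/checkpoints/mislabeled.ckpt", safe_load=True)
```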
View File
@ -13,7 +13,7 @@ from transformers.models.nllb.tokenization_nllb import \
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import SaveImagePathResponse from comfy.component_model.folder_path_types import SaveImagePathTuple
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \ from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
TOKENS_TYPE_NAME, LanguageModel TOKENS_TYPE_NAME, LanguageModel
@ -397,7 +397,7 @@ class SaveString(CustomNode):
OUTPUT_NODE = True OUTPUT_NODE = True
RETURN_TYPES = () RETURN_TYPES = ()
def get_save_path(self, filename_prefix) -> SaveImagePathResponse: def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0) return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"): def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
View File
@ -1,21 +0,0 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}
10
main.py
View File
@ -2,13 +2,15 @@ import asyncio
import warnings import warnings
from pathlib import Path from pathlib import Path
if __name__ == "__main__": from comfy.component_model.folder_path_types import FolderNames
from comfy.cmd.folder_paths_pre import set_base_path
if __name__ == "__main__":
warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning) warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning)
this_file_parent_dir = Path(__file__).parent this_file_parent_dir = Path(__file__).parent
set_base_path(str(this_file_parent_dir))
from comfy.cmd.main import main from comfy.cmd.main import main
from comfy.cmd.folder_paths import folder_names_and_paths # type: FolderNames
fn: FolderNames = folder_names_and_paths
fn.base_paths.clear()
fn.base_paths.append(this_file_parent_dir)
asyncio.run(main(from_script_dir=this_file_parent_dir)) asyncio.run(main(from_script_dir=this_file_parent_dir))
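Embedders can apply the same redirection without going through `main.py`; a hedged sketch mirroring the lines above:

```python
# Hypothetical sketch: point model lookups at a custom root before startup.
from pathlib import Path
from comfy.cmd.folder_paths import folder_names_and_paths

folder_names_and_paths.base_paths.clear()
folder_names_and_paths.base_paths.append(Path("/srv/comfyui"))
```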
View File
@ -1,6 +1,8 @@
import asyncio import asyncio
import logging
logging.basicConfig(level=logging.ERROR)
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Callable from typing import Callable
import jwt import jwt
@ -15,10 +17,12 @@ from comfy.component_model.executor_types import Executor
from comfy.component_model.make_mutable import make_mutable from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.executors import ContextVarExecutor
from comfy.distributed.process_pool_executor import ProcessPoolExecutor from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub from comfy.distributed.server_stub import ServerStub
def create_test_prompt() -> QueueItem: def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt from comfy.cmd.execution import validate_prompt
@ -37,8 +41,11 @@ async def test_sign_jwt_auth_none():
assert user_token["sub"] == client_id assert user_token["sub"] == client_id
_executor_factories: tuple[Executor] = (ContextVarExecutor,)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,)) @pytest.mark.parametrize("executor_factory", _executor_factories)
async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None: async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
@ -72,7 +79,7 @@ async def test_distributed_prompt_queues_same_process():
frontend.put(test_prompt) frontend.put(test_prompt)
# start a worker thread # start a worker thread
thread_pool = ThreadPoolExecutor(max_workers=1) thread_pool = ContextVarExecutor(max_workers=1)
async def in_thread(): async def in_thread():
incoming, incoming_prompt_id = worker.get() incoming, incoming_prompt_id = worker.get()
@ -127,7 +134,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,)) @pytest.mark.parametrize("executor_factory", _executor_factories)
async def test_basic_queue_worker_with_health_check(executor_factory): async def test_basic_queue_worker_with_health_check(executor_factory):
with RabbitMqContainer("rabbitmq:latest") as rabbitmq: with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params() params = rabbitmq.get_connection_params()
View File
@ -1,3 +1,4 @@
import logging
import uuid import uuid
from contextvars import ContextVar from contextvars import ContextVar
from typing import Dict, Optional from typing import Dict, Optional
@ -73,7 +74,8 @@ class ComfyClient:
prompt_id = str(uuid.uuid4()) prompt_id = str(uuid.uuid4())
try: try:
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id) outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
except (RuntimeError, DependencyCycleError): except (RuntimeError, DependencyCycleError) as exc_info:
logging.warning("error when queueing prompt", exc_info=exc_info)
outputs = {} outputs = {}
result = RunResult(prompt_id=prompt_id) result = RunResult(prompt_id=prompt_id)
result.outputs = outputs result.outputs = outputs
View File
@ -2,11 +2,13 @@
# TODO(yoland): clean up this after I get back down # TODO(yoland): clean up this after I get back down
import os import os
import tempfile import tempfile
from unittest.mock import patch from pathlib import Path
import pytest import pytest
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import FolderNames, ModelPaths
from comfy.execution_context import context_folder_names_and_paths
@pytest.fixture @pytest.fixture
@ -40,16 +42,6 @@ def test_add_model_folder_path():
assert "/test/path" in folder_paths.get_folder_paths("test_folder") assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
files, dirs = folder_paths.recursive_search(temp_dir)
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions(): def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"] files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"] assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
@ -57,16 +49,24 @@ def test_filter_files_extensions():
assert folder_paths.filter_files_extensions(files, []) == files assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search") def test_get_filename_list(temp_dir):
@patch("folder_paths.folder_names_and_paths") base_path = Path(temp_dir)
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search): fn = FolderNames(base_paths=[base_path])
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"}) rel_path = Path("test/path")
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {}) fn.add(ModelPaths(["test_folder"], additional_relative_directory_paths={rel_path}, supported_extensions={".txt"}))
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"] dir_path = base_path / rel_path
Path.mkdir(dir_path, parents=True, exist_ok=True)
files = ["file1.txt", "file2.jpg"]
for file in files:
Path.touch(dir_path / file, exist_ok=True)
with context_folder_names_and_paths(fn):
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir): def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir): with context_folder_names_and_paths(FolderNames(base_paths=[Path(temp_dir)])):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
assert os.path.samefile(full_output_folder, temp_dir) assert os.path.samefile(full_output_folder, temp_dir)
assert filename == "test" assert filename == "test"