--base-paths argument adds additional paths to search for models/checkpoints, models/loras, etc. directories, including directories specified in this pattern by custom nodes

This commit is contained in:
doctorpangloss 2024-10-28 19:04:14 -07:00
parent 89d07f3adf
commit 4a13766d14
30 changed files with 932 additions and 525 deletions

101
README.md
View File

@ -722,19 +722,31 @@ You can pass additional extra model path configurations with one or more copies
### Command Line Arguments
```
usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [-H [IP]] [--port PORT] [--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE]
[--extra-model-paths-config PATH [PATH ...]] [--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch]
[--disable-auto-launch] [--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16]
usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [--base-paths BASE_PATHS [BASE_PATHS ...]] [-H [IP]] [--port PORT]
[--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE] [--extra-model-paths-config PATH [PATH ...]]
[--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch] [--disable-auto-launch]
[--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16]
[--bf16-unet | --fp16-unet | --fp8_e4m3fn-unet | --fp8_e5m2-unet] [--fp16-vae | --fp32-vae | --bf16-vae] [--cpu-vae]
[--fp8_e4m3fn-text-enc | --fp8_e5m2-text-enc | --fp16-text-enc | --fp32-text-enc] [--directml [DIRECTML_DEVICE]] [--disable-ipex-optimize]
[--preview-method [none,auto,latent2rgb,taesd]] [--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers]
[--force-upcast-attention | --dont-upcast-attention] [--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--disable-smart-memory] [--deterministic]
[--dont-print-server] [--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--multi-user] [--create-directories]
[--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN] [--analytics-use-identity-provider]
[--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker] [--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME]
[--external-address EXTERNAL_ADDRESS] [--verbose] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE] [--otel-service-name OTEL_SERVICE_NAME]
[--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT]
[--preview-method [none,auto,latent2rgb,taesd]] [--preview-size PREVIEW_SIZE] [--cache-lru CACHE_LRU]
[--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers] [--disable-flash-attn]
[--disable-sage-attention] [--force-upcast-attention | --dont-upcast-attention]
[--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--reserve-vram RESERVE_VRAM]
[--default-hashing-function {md5,sha1,sha256,sha512}] [--disable-smart-memory] [--deterministic] [--fast] [--dont-print-server]
[--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--disable-all-custom-nodes] [--multi-user] [--create-directories]
[--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN]
[--analytics-use-identity-provider] [--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker]
[--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME] [--external-address EXTERNAL_ADDRESS]
[--logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE]
[--otel-service-name OTEL_SERVICE_NAME] [--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT]
[--force-channels-last] [--force-hf-local-dir-mode] [--front-end-version FRONT_END_VERSION] [--front-end-root FRONT_END_ROOT]
[--executor-factory EXECUTOR_FACTORY] [--openai-api-key OPENAI_API_KEY] [--user-directory USER_DIRECTORY] [--blip-model-url BLIP_MODEL_URL]
[--blip-model-vqa-url BLIP_MODEL_VQA_URL] [--sam-model-vith-url SAM_MODEL_VITH_URL] [--sam-model-vitl-url SAM_MODEL_VITL_URL]
[--sam-model-vitb-url SAM_MODEL_VITB_URL] [--history-display-limit HISTORY_DISPLAY_LIMIT] [--ffmpeg-bin-path FFMPEG_BIN_PATH]
[--ffmpeg-extra-codecs FFMPEG_EXTRA_CODECS] [--wildcards-path WILDCARDS_PATH] [--wildcard-api WILDCARD_API] [--photoprism-host PHOTOPRISM_HOST]
[--immich-host IMMICH_HOST] [--ideogram-session-cookie IDEOGRAM_SESSION_COOKIE] [--annotator-ckpts-path ANNOTATOR_CKPTS_PATH] [--use-symlinks]
[--ort-providers ORT_PROVIDERS] [--vfi-ops-backend VFI_OPS_BACKEND] [--dependency-version DEPENDENCY_VERSION] [--mmdet-skip] [--sam-editor-cpu]
[--sam-editor-model SAM_EDITOR_MODEL] [--custom-wildcards CUSTOM_WILDCARDS] [--disable-gpu-opencv]
options:
-h, --help show this help message and exit
@ -742,10 +754,14 @@ options:
config file path
--write-out-config-file CONFIG_OUTPUT_PATH
takes the current command line args and writes them out to a config file at the given path, then exits
-w CWD, --cwd CWD Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default. [env var:
COMFYUI_CWD]
-w CWD, --cwd CWD Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be
located here by default. [env var: COMFYUI_CWD]
--base-paths BASE_PATHS [BASE_PATHS ...]
Additional base paths for custom nodes, models and inputs. [env var: COMFYUI_BASE_PATHS]
-H [IP], --listen [IP]
Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all) [env var: COMFYUI_LISTEN]
Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like:
127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6) [env var:
COMFYUI_LISTEN]
--port PORT Set the listen port. [env var: COMFYUI_PORT]
--enable-cors-header [ORIGIN]
Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'. [env var: COMFYUI_ENABLE_CORS_HEADER]
@ -789,6 +805,10 @@ options:
Disables ipex.optimize when loading models with Intel GPUs. [env var: COMFYUI_DISABLE_IPEX_OPTIMIZE]
--preview-method [none,auto,latent2rgb,taesd]
Default preview method for sampler nodes. [env var: COMFYUI_PREVIEW_METHOD]
--preview-size PREVIEW_SIZE
Sets the maximum preview size for sampler nodes. [env var: COMFYUI_PREVIEW_SIZE]
--cache-lru CACHE_LRU
Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM. [env var: COMFYUI_CACHE_LRU]
--use-split-cross-attention
Use the split cross attention optimization. Ignored when xformers is used. [env var: COMFYUI_USE_SPLIT_CROSS_ATTENTION]
--use-quad-cross-attention
@ -796,6 +816,9 @@ options:
--use-pytorch-cross-attention
Use the new pytorch 2.0 cross attention function. [env var: COMFYUI_USE_PYTORCH_CROSS_ATTENTION]
--disable-xformers Disable xformers. [env var: COMFYUI_DISABLE_XFORMERS]
--disable-flash-attn Disable Flash Attention [env var: COMFYUI_DISABLE_FLASH_ATTN]
--disable-sage-attention
Disable Sage Attention [env var: COMFYUI_DISABLE_SAGE_ATTENTION]
--force-upcast-attention
Force enable attention upcasting, please report if it fixes black images. [env var: COMFYUI_FORCE_UPCAST_ATTENTION]
--dont-upcast-attention
@ -806,15 +829,25 @@ options:
--lowvram Split the unet in parts to use less vram. [env var: COMFYUI_LOWVRAM]
--novram When lowvram isn't enough. [env var: COMFYUI_NOVRAM]
--cpu To use the CPU for everything (slow). [env var: COMFYUI_CPU]
--reserve-vram RESERVE_VRAM
Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.
[env var: COMFYUI_RESERVE_VRAM]
--default-hashing-function {md5,sha1,sha256,sha512}
Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256. [env var:
COMFYUI_DEFAULT_HASHING_FUNCTION]
--disable-smart-memory
Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can. [env var: COMFYUI_DISABLE_SMART_MEMORY]
--deterministic Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var: COMFYUI_DETERMINISTIC]
--deterministic Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var:
COMFYUI_DETERMINISTIC]
--fast Enable some untested and potentially quality deteriorating optimizations. [env var: COMFYUI_FAST]
--dont-print-server Don't print server output. [env var: COMFYUI_DONT_PRINT_SERVER]
--quick-test-for-ci Quick test for CI. Raises an error if nodes cannot be imported, [env var: COMFYUI_QUICK_TEST_FOR_CI]
--windows-standalone-build
Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup). [env var:
COMFYUI_WINDOWS_STANDALONE_BUILD]
Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening
the page on startup). [env var: COMFYUI_WINDOWS_STANDALONE_BUILD]
--disable-metadata Disable saving prompt metadata in files. [env var: COMFYUI_DISABLE_METADATA]
--disable-all-custom-nodes
Disable loading all custom nodes. [env var: COMFYUI_DISABLE_ALL_CUSTOM_NODES]
--multi-user Enables per-user storage. [env var: COMFYUI_MULTI_USER]
--create-directories Creates the default models/, input/, output/ and temp/ directories, then exits. [env var: COMFYUI_CREATE_DIRECTORIES]
--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL
@ -824,18 +857,19 @@ options:
--analytics-use-identity-provider
Uses platform identifiers for unique visitor analytics. [env var: COMFYUI_ANALYTICS_USE_IDENTITY_PROVIDER]
--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI
EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress
updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI]
EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt
execution requests and progress updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI]
--distributed-queue-worker
Workers will pull requests off the AMQP URL. [env var: COMFYUI_DISTRIBUTED_QUEUE_WORKER]
--distributed-queue-frontend
Frontends will start the web UI and connect to the provided AMQP URL to submit prompts. [env var: COMFYUI_DISTRIBUTED_QUEUE_FRONTEND]
--distributed-queue-name DISTRIBUTED_QUEUE_NAME
This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the
user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME]
This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue
name, followed by a '.', then the user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME]
--external-address EXTERNAL_ADDRESS
Specifies a base URL for external addresses reported by the API, such as for image paths. [env var: COMFYUI_EXTERNAL_ADDRESS]
--verbose Enables more debug prints. [env var: COMFYUI_VERBOSE]
--logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
Set the logging level [env var: COMFYUI_LOGGING_LEVEL]
--disable-known-models
Disables automatic downloads of known models and prevents them from appearing in the UI. [env var: COMFYUI_DISABLE_KNOWN_MODELS]
--max-queue-size MAX_QUEUE_SIZE
@ -845,8 +879,27 @@ options:
--otel-service-version OTEL_SERVICE_VERSION
The version of the service or application that is generating telemetry data. [env var: OTEL_SERVICE_VERSION]
--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT
A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one
environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT]
A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the
same endpoint and want one environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT]
--force-channels-last
Force channels last format when inferencing the models. [env var: COMFYUI_FORCE_CHANNELS_LAST]
--force-hf-local-dir-mode
Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with
the "cache_dir" argument, recreating the traditional file structure. [env var: COMFYUI_FORCE_HF_LOCAL_DIR_MODE]
--front-end-version FRONT_END_VERSION
Specifies the version of the frontend to be used. This command needs internet connectivity to query and download available frontend
implementations from GitHub releases. The version string should be in the format of: [repoOwner]/[repoName]@[version] where version is one of:
"latest" or a valid version number (e.g. "1.0.0") [env var: COMFYUI_FRONT_END_VERSION]
--front-end-root FRONT_END_ROOT
The local filesystem path to the directory where the frontend is located. Overrides --front-end-version. [env var: COMFYUI_FRONT_END_ROOT]
--executor-factory EXECUTOR_FACTORY
When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow
worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and
large, contiguous blocks of CUDA memory can be freed. [env var: COMFYUI_EXECUTOR_FACTORY]
--openai-api-key OPENAI_API_KEY
Configures the OpenAI API Key for the OpenAI nodes [env var: OPENAI_API_KEY]
--user-directory USER_DIRECTORY
Set the ComfyUI user directory with an absolute path. [env var: COMFYUI_USER_DIRECTORY]
Args that start with '--' can also be set in a config file (config.yaml or config.json or specified via -c). Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see syntax at
https://goo.gl/R74nmi). In general, command-line values override environment variables which override config file values which override defaults.

View File

@ -28,6 +28,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument('-w', "--cwd", type=str, default=None,
help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.")
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
@ -161,7 +162,7 @@ def _create_parser() -> EnhancedConfigArgParser:
help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
parser.add_argument("--external-address", required=False,
help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--logging-level", type=str, default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.")
parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.")
# tracing
@ -251,11 +252,6 @@ def _parse_args(parser: Optional[argparse.ArgumentParser] = None, args_parsing:
if args.disable_auto_launch:
args.auto_launch = False
logging_level = logging.INFO
if args.verbose:
logging_level = logging.DEBUG
logging.basicConfig(format="%(message)s", level=logging_level)
configuration_obj = Configuration(**vars(args))
configuration_obj.config_files = config_files
assert all(isinstance(config_file, str) for config_file in config_files)

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import collections
import enum
from pathlib import Path
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
import configargparse
@ -36,15 +38,16 @@ class Configuration(dict):
Attributes:
config_files (Optional[List[str]]): Path to the configuration file(s) that were set in the arguments.
cwd (Optional[str]): Working directory. Defaults to the current directory.
cwd (Optional[str]): Working directory. Defaults to the current directory. This is always treated as a base path for model files, and it will be the place where model files are downloaded.
base_paths (Optional[list[str]]): Additional base paths for custom nodes, models and inputs.
listen (str): IP address to listen on. Defaults to "127.0.0.1".
port (int): Port number for the server to listen on. Defaults to 8188.
enable_cors_header (Optional[str]): Enables CORS with the specified origin.
max_upload_size (float): Maximum upload size in MB. Defaults to 100.
extra_model_paths_config (Optional[List[str]]): Extra model paths configuration files.
output_directory (Optional[str]): Directory for output files.
output_directory (Optional[str]): Directory for output files. This can also be a relative path to the cwd or current working directory.
temp_directory (Optional[str]): Temporary directory for processing.
input_directory (Optional[str]): Directory for input files.
input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths.
auto_launch (bool): Auto-launch UI in the default browser. Defaults to False.
disable_auto_launch (bool): Disable auto-launching the browser.
cuda_device (Optional[int]): CUDA device ID. None means default device.
@ -87,7 +90,6 @@ class Configuration(dict):
reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS
disable_smart_memory (bool): Disable smart memory management.
deterministic (bool): Use deterministic algorithms where possible.
dont_print_server (bool): Suppress server output.
quick_test_for_ci (bool): Enable quick testing mode for CI.
windows_standalone_build (bool): Enable features for standalone Windows build.
disable_metadata (bool): Disable saving metadata with outputs.
@ -103,7 +105,7 @@ class Configuration(dict):
distributed_queue_worker (bool): Workers will pull requests off the AMQP URL.
distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
verbose (bool | str): Shows extra output for debugging purposes such as import errors of custom nodes; or, specifies a log level
logging_level (str): Specifies a log level
disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI.
max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value.
otel_service_name (str): The name of the service or application that is generating telemetry data. Default: "comfyui".
@ -122,6 +124,7 @@ class Configuration(dict):
self._observers: List[ConfigObserver] = []
self.config_files = []
self.cwd: Optional[str] = None
self.base_paths: list[Path] = []
self.listen: str = "127.0.0.1"
self.port: int = 8188
self.enable_cors_header: Optional[str] = None
@ -192,7 +195,7 @@ class Configuration(dict):
self.force_channels_last: bool = False
self.force_hf_local_dir_mode = False
self.preview_size: int = 512
self.verbose: str | bool = "INFO"
self.logging_level: str = "INFO"
# from guill
self.cache_lru: int = 0
@ -253,6 +256,17 @@ class Configuration(dict):
self.update(state)
self._observers = []
@property
def verbose(self) -> str:
return self.logging_level
@verbose.setter
def verbose(self, value):
if isinstance(value, bool):
self.logging_level = "DEBUG"
else:
self.logging_level = value
class EnumAction(argparse.Action):
"""

View File

@ -1,13 +1,14 @@
from __future__ import annotations
import asyncio
import contextvars
import gc
import json
import os
import threading
import uuid
from asyncio import get_event_loop
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import RLock
from pathlib import Path
from typing import Optional
from opentelemetry import context, propagate
@ -17,13 +18,16 @@ from opentelemetry.trace import Status, StatusCode
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cmd.folder_paths import init_default_paths
from ..cmd.main_pre import tracer
from ..component_model.executor_types import ExecutorToClientProgress, Executor
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable
from ..distributed.executors import ContextVarExecutor
from ..distributed.process_pool_executor import ProcessPoolExecutor
from ..distributed.server_stub import ServerStub
from ..execution_context import current_execution_context
_prompt_executor = contextvars.ContextVar('prompt_executor')
_prompt_executor = threading.local()
def _execute_prompt(
@ -33,6 +37,9 @@ def _execute_prompt(
span_context: dict,
progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None) -> dict:
execution_context = current_execution_context()
if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
init_default_paths(execution_context.folder_names_and_paths, configuration)
span_context: Context = propagate.extract(span_context)
token = attach(span_context)
try:
@ -52,8 +59,8 @@ def __execute_prompt(
progress_handler = progress_handler or ServerStub()
try:
prompt_executor = _prompt_executor.get()
except LookupError:
prompt_executor: PromptExecutor = _prompt_executor.executor
except (LookupError, AttributeError):
if configuration is None:
options.enable_args_parsing()
else:
@ -65,7 +72,7 @@ def __execute_prompt(
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
prompt_executor.raise_exceptions = True
_prompt_executor.set(prompt_executor)
_prompt_executor.executor = prompt_executor
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
try:
@ -96,6 +103,13 @@ def __execute_prompt(
def _cleanup():
from ..cmd.execution import PromptExecutor
try:
prompt_executor: PromptExecutor = _prompt_executor.executor
# this should clear all references to output tensors and make it easier to collect back the memory
prompt_executor.reset()
except (LookupError, AttributeError):
pass
from .. import model_management
model_management.unload_all_models()
gc.collect()
@ -139,9 +153,9 @@ class EmbeddedComfyClient:
In order to use this in blocking methods, learn more about asyncio online.
"""
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor = None):
self._progress_handler = progress_handler or ServerStub()
self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
self._executor = executor or ContextVarExecutor(max_workers=max_workers)
self._configuration = configuration
self._is_running = False
self._task_count_lock = RLock()

View File

@ -27,7 +27,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
from ..execution_context import context_execute_node, context_execute_prompt
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@ -77,24 +77,19 @@ class IsChangedCache:
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
# Performs like the old cache -- dump data ASAP
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
else:
self.init_lru_cache(lru_size)
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
self.objects = HierarchicalCache(CacheKeySetID)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
@ -308,11 +303,11 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
:param pending_subgraph_results:
:return:
"""
with context_execute_node(_node_id, prompt_id):
with context_execute_node(_node_id):
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -548,7 +543,7 @@ class PromptExecutor:
# torchao and potentially other optimization approaches break when the models are created in inference mode
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode):
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):

View File

@ -4,142 +4,144 @@ import logging
import mimetypes
import os
import time
from typing import Optional, List, Final, Literal
from contextlib import nullcontext
from pathlib import Path
from typing import Optional, List, Literal
from .folder_paths_pre import get_base_path
from ..cli_args_types import Configuration
from ..component_model.deprecation import _deprecate_method
from ..component_model.files import get_package_as_path
from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
from ..component_model.folder_path_types import extension_mimetypes_cache as _extension_mimetypes_cache
from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
from ..component_model.module_property import module_property
from ..component_model.folder_path_types import FolderNames, SaveImagePathTuple, ModelPaths
from ..component_model.folder_path_types import supported_pt_extensions, extension_mimetypes_cache
from ..component_model.module_property import create_module_properties
from ..component_model.platform_path import construct_path
from ..execution_context import current_execution_context
supported_pt_extensions: Final[frozenset[str]] = _supported_pt_extensions
extension_mimetypes_cache: Final[dict[str, str]] = _extension_mimetypes_cache
_module_properties = create_module_properties()
@_module_properties.getter
def _supported_pt_extensions() -> frozenset[str]:
return supported_pt_extensions
@_module_properties.getter
def _extension_mimetypes_cache() -> dict[str, str]:
return extension_mimetypes_cache
# todo: this needs to be wrapped in a context and configurable
@module_property
@_module_properties.getter
def _base_path():
return get_base_path()
return _folder_names_and_paths().base_paths[0]
models_dir = os.path.join(get_base_path(), "models")
folder_names_and_paths: Final[FolderNames] = FolderNames(models_dir)
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
folder_names_and_paths["unet"] = folder_names_and_paths["diffusion_models"] = FolderPathsTuple("diffusion_models", [os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], set(supported_pt_extensions))
folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))
folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"})
folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions))
folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions))
folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions))
folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions))
folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(get_base_path(), "custom_nodes")], set())
folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""})
folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""})
output_directory = os.path.join(get_base_path(), "output")
temp_directory = os.path.join(get_base_path(), "temp")
input_directory = os.path.join(get_base_path(), "input")
user_directory = os.path.join(get_base_path(), "user")
filename_list_cache = {}
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None):
from ..cmd.main_pre import args
configuration = configuration or args
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
base_paths = [path for path in base_paths if path is not None]
if len(base_paths) == 0:
base_paths = [Path(os.getcwd())]
for base_path in base_paths:
folder_names_and_paths.add_base_path(base_path)
folder_names_and_paths.add(ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}))
folder_names_and_paths.add(ModelPaths(["vae"], supported_extensions={".yaml"}))
folder_names_and_paths.add(ModelPaths(["clip"], supported_extensions={".yaml"}))
folder_names_and_paths.add(ModelPaths(["loras"], supported_extensions={".yaml"}))
folder_names_and_paths.add(ModelPaths(["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["diffusers"], supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["classifiers"], supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["huggingface"], supported_extensions=set()))
hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set())
# TODO: explore if there is a better way to do this
if "HF_HUB_CACHE" in os.environ:
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
folder_names_and_paths.add(hf_cache_paths)
create_directories(folder_names_and_paths)
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
@_module_properties.getter
def _folder_names_and_paths():
return current_execution_context().folder_names_and_paths
cache_helper = CacheHelper()
@_module_properties.getter
def _models_dir():
return str(Path(current_execution_context().folder_names_and_paths.base_paths[0]) / construct_path("models"))
@_module_properties.getter
def _user_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve())
@_module_properties.getter
def _temp_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve())
@_module_properties.getter
def _input_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve())
@_module_properties.getter
def _output_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve())
@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.")
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory):
try:
os.makedirs(input_directory)
except:
logging.error("Failed to create input directory")
def set_output_directory(output_dir: str | Path):
_folder_names_and_paths().application_paths.output_directory = construct_path(output_dir)
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
def set_temp_directory(temp_dir: str | Path):
_folder_names_and_paths().application_paths.temp_directory = construct_path(temp_dir)
def set_temp_directory(temp_dir):
global temp_directory
temp_directory = temp_dir
def set_input_directory(input_dir: str | Path):
_folder_names_and_paths().application_paths.input_directory = construct_path(input_dir)
def set_input_directory(input_dir):
global input_directory
input_directory = input_dir
def get_output_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve())
def get_output_directory():
global output_directory
return output_directory
def get_temp_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve())
def get_temp_directory():
global temp_directory
return temp_directory
def get_input_directory():
global input_directory
return input_directory
def get_input_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve())
def get_user_directory() -> str:
return user_directory
return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve())
def set_user_directory(user_dir: str) -> None:
global user_directory
user_directory = user_dir
def set_user_directory(user_dir: str | Path) -> None:
_folder_names_and_paths().application_paths.user_directory = construct_path(user_dir)
# NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
def get_directory_by_type(type_name) -> str | None:
if type_name == "output":
return get_output_directory()
if type_name == "temp":
@ -151,7 +153,7 @@ def get_directory_by_type(type_name):
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
def annotated_filepath(name):
def annotated_filepath(name: str) -> tuple[str, str | None]:
if name.endswith("[output]"):
base_dir = get_output_directory()
name = name[:-9]
@ -167,7 +169,7 @@ def annotated_filepath(name):
return name, base_dir
def get_annotated_filepath(name, default_dir=None):
def get_annotated_filepath(name, default_dir=None) -> str:
name, base_dir = annotated_filepath(name)
if base_dir is None:
@ -198,9 +200,11 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
:param extensions: supported file extensions
:return: the folder path
"""
global folder_names_and_paths
folder_names_and_paths = _folder_names_and_paths()
if full_folder_path is None:
full_folder_path = os.path.join(models_dir, folder_name)
# todo: this should use the subdir patter
full_folder_path = os.path.join(_models_dir(), folder_name)
folder_path = folder_names_and_paths[folder_name]
if full_folder_path not in folder_path.paths:
@ -212,48 +216,16 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
if extensions is not None:
folder_path.supported_extensions |= extensions
invalidate_cache(folder_name)
return full_folder_path
def get_folder_paths(folder_name) -> List[str]:
return folder_names_and_paths[folder_name].paths[:]
return [path for path in _folder_names_and_paths()[folder_name].paths]
def recursive_search(directory, excluded_dir_names=None):
if not os.path.isdir(directory):
return [], {}
if excluded_dir_names is None:
excluded_dir_names = []
result = []
dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs
@_deprecate_method(version="0.2.3", message="Not supported")
def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]:
raise NotImplemented("Unsupported method")
def filter_files_extensions(files, extensions):
@ -264,92 +236,27 @@ def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
"""
Gets the path to a filename inside a folder.
Works with untrusted filenames.
:param folder_name:
:param filename:
:return:
"""
global folder_names_and_paths
folders = folder_names_and_paths[folder_name].paths
filename_split = os.path.split(filename)
trusted_paths = []
for folder in folders:
folder_split = os.path.split(folder)
abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split))
abs_folder_path = os.path.abspath(folder)
if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path:
trusted_paths.append(abs_file_path)
else:
logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}")
for trusted_path in trusted_paths:
if os.path.isfile(trusted_path):
return trusted_path
return None
path = _folder_names_and_paths().first_existing_or_none(folder_name, construct_path(filename))
return str(path) if path is not None else None
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
# todo: probably shouldn't say model
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
output_folders = {}
for x in folders[0]:
files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
output_list.update(filter_files_extensions(files, folders[1]))
output_folders = {**output_folders, **folders_all}
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = filename_list_cache[folder_name]
for x in out[1]:
time_modified = out[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
folders = folder_names_and_paths[folder_name]
for x in folders[0]:
if os.path.isdir(x):
if x not in out[1]:
return None
return out
def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
return [str(path) for path in _folder_names_and_paths().file_paths(folder_name=folder_name, relative=True)]
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0) -> SaveImagePathTuple:
def map_filename(filename: str) -> tuple[int, str]:
prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1]
@ -378,14 +285,6 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
full_output_folder = str(os.path.join(output_dir, subfolder))
if str(os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))) != str(output_dir):
err = f"""**** ERROR: Saving image outside the output folder is not allowed.
full_output_folder: {os.path.abspath(full_output_folder)}
output_dir: {output_dir}
commonpath: {os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))}"""
logging.error(err)
raise Exception(err)
try:
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
except ValueError:
@ -393,21 +292,22 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
except FileNotFoundError:
os.makedirs(full_output_folder, exist_ok=True)
counter = 1
return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix)
return SaveImagePathTuple(full_output_folder, filename, counter, subfolder, filename_prefix)
def create_directories():
def create_directories(paths: FolderNames | None):
# all configured paths should be created
for folder_path_spec in folder_names_and_paths.values():
paths = paths or _folder_names_and_paths()
for folder_path_spec in paths.values():
for path in folder_path_spec.paths:
os.makedirs(path, exist_ok=True)
for path in (temp_directory, input_directory, output_directory, user_directory):
os.makedirs(path, exist_ok=True)
for path in paths.application_paths:
path.mkdir(exist_ok=True)
@_deprecate_method(version="0.2.3", message="Caching has been removed.")
def invalidate_cache(folder_name):
global filename_list_cache
filename_list_cache.pop(folder_name, None)
pass
def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
@ -416,7 +316,7 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
files = os.listdir(folder_paths.get_input_directory())
filter_files_content_types(files, ["image", "audio", "video"])
"""
global extension_mimetypes_cache
extension_mimetypes_cache = _extension_mimetypes_cache()
result = []
for file in files:
extension = file.split('.')[-1]
@ -432,3 +332,52 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
if content_type in content_types:
result.append(file)
return result
@_module_properties.getter
def _cache_helper():
return nullcontext()
# todo: can this be done side effect free?
init_default_paths(_folder_names_and_paths())
__all__ = [
# Properties (stripped leading underscore)
"supported_pt_extensions", # from _supported_pt_extensions
"extension_mimetypes_cache", # from _extension_mimetypes_cache
"base_path", # from _base_path
"folder_names_and_paths", # from _folder_names_and_paths
"models_dir", # from _models_dir
"user_directory",
"output_directory",
"temp_directory",
"input_directory",
# Public functions
"init_default_paths",
"map_legacy",
"set_output_directory",
"set_temp_directory",
"set_input_directory",
"get_output_directory",
"get_temp_directory",
"get_input_directory",
"get_user_directory",
"set_user_directory",
"get_directory_by_type",
"annotated_filepath",
"get_annotated_filepath",
"exists_annotated_filepath",
"add_model_folder_path",
"get_folder_paths",
"recursive_search",
"filter_files_extensions",
"get_full_path",
"get_full_path_or_raise",
"get_filename_list",
"get_save_image_path",
"create_directories",
"invalidate_cache",
"filter_files_content_types"
]

View File

@ -1,31 +0,0 @@
import logging
import os
from ..cli_args import args
_base_path = None
# todo: this should be initialized elsewhere in a context
def get_base_path() -> str:
global _base_path
if _base_path is None:
if args.cwd is not None:
if not os.path.exists(args.cwd):
try:
os.makedirs(args.cwd, exist_ok=True)
except:
logging.error("Failed to create custom working directory")
# wrap the path to prevent slashedness from glitching out common path checks
_base_path = os.path.realpath(args.cwd)
else:
_base_path = os.getcwd()
return _base_path
def set_base_path(value: str):
global _base_path
_base_path = value
__all__ = ["get_base_path", "set_base_path"]

View File

@ -1,4 +1,5 @@
import asyncio
import contextvars
import gc
import itertools
import logging
@ -193,7 +194,9 @@ async def main(from_script_dir: Optional[Path] = None):
if not distributed or args.distributed_queue_worker:
if distributed:
logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
# todo: this should really be using an executor instead of doing things this jankilicious way
ctx = contextvars.copy_context()
threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start()
# server has been imported and things should be looking good
initialize_event_tracking(loop)

View File

@ -111,13 +111,7 @@ def _create_tracer():
def _configure_logging():
if isinstance(args.verbose, str):
logging_level = args.verbose
elif args.verbose == True:
logging_level = logging.DEBUG
else:
logging_level = logging.ERROR
logging_level = args.logging_level
logging.basicConfig(format="%(message)s", level=logging_level)

View File

@ -28,8 +28,8 @@ from aiohttp import web
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from .. import __version__
from .latent_preview_image_encoding import encode_preview_image
from .. import __version__
from .. import interruption, model_management
from .. import node_helpers
from .. import utils
@ -241,7 +241,7 @@ class PromptServer(ExecutorToClientProgress):
return response
@routes.get("/embeddings")
def get_embeddings(self):
def get_embeddings(request):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@ -568,15 +568,14 @@ class PromptServer(ExecutorToClientProgress):
@routes.get("/object_info")
async def get_object_info(request):
with folder_paths.cache_helper:
out = {}
for x in self.nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
out = {}
for x in self.nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
@ -721,21 +720,21 @@ class PromptServer(ExecutorToClientProgress):
async def get_api_v1_prompts_prompt_id(request: web.Request) -> web.Response | web.FileResponse:
prompt_id: str = request.match_info.get("prompt_id", "")
if prompt_id == "":
return web.Response(status=404)
return web.json_response(status=404)
history_items = self.prompt_queue.get_history(prompt_id)
if len(history_items) == 0 or prompt_id not in history_items:
# todo: this should really be moved to a stateful queue abstraction
if prompt_id in self.background_tasks:
return web.Response(status=204)
return web.json_response(status=204)
else:
# todo: this should check a stateful queue abstraction
return web.Response(status=404)
return web.json_response(status=404)
elif prompt_id in history_items:
history_entry = history_items[prompt_id]
return web.json_response(history_entry["outputs"])
else:
return web.Response(status=500)
return web.json_response(status=500)
@routes.post("/api/v1/prompts")
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
@ -751,8 +750,8 @@ class PromptServer(ExecutorToClientProgress):
queue_size = self.prompt_queue.size()
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
if queue_size > queue_too_busy_size:
return web.Response(status=429,
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
return web.json_response(status=429,
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
# read the request
prompt_dict: dict = {}
if content_type == 'application/json':

View File

@ -2,10 +2,11 @@ import asyncio
import itertools
import logging
import os
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor
from .extra_model_paths import load_extra_path_config
from .main_pre import args
from .extra_model_paths import load_extra_path_config
from ..distributed.executors import ContextVarExecutor
async def main():
@ -43,7 +44,7 @@ async def main():
from ..distributed.distributed_prompt_worker import DistributedPromptWorker
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
queue_name=args.distributed_queue_name,
executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
stop = asyncio.Event()
try:
await stop.wait()

View File

@ -8,6 +8,7 @@ from typing import Optional, Literal, Protocol, Union, NamedTuple, List
import PIL.Image
from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
from .queue_types import BinaryEventTypes
from ..cli_args_types import Configuration
from ..nodes.package_typing import InputTypeSpec
@ -205,8 +206,8 @@ class DuplicateNodeError(Exception):
class HistoryResultDict(TypedDict, total=True):
outputs: dict
meta: dict
outputs: OutputsDict
meta: OutputsDict
class DependencyCycleError(Exception):

View File

@ -1,15 +1,9 @@
import os
from typing import Literal, Optional
from pathlib import Path
from typing import Literal, Optional
from ..cmd import folder_paths
def _is_strictly_below_root(path: Path) -> bool:
resolved_path = path.resolve()
return ".." not in resolved_path.parts and resolved_path.is_absolute()
def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "output",
subfolder: Optional[str] = None) -> str:
"""
@ -22,22 +16,15 @@ def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "
:return:
"""
filename, output_dir = folder_paths.annotated_filepath(str(filename))
if not _is_strictly_below_root(Path(filename)):
raise PermissionError("insecure")
if output_dir is None:
output_dir = folder_paths.get_directory_by_type(type)
if output_dir is None:
raise ValueError(f"no such output directory because invalid type specified (type={type})")
if subfolder is not None and subfolder != "":
full_output_dir = str(os.path.join(output_dir, subfolder))
if str(os.path.commonpath([os.path.abspath(full_output_dir), output_dir])) != str(output_dir):
raise PermissionError("insecure")
output_dir = full_output_dir
filename = os.path.basename(filename)
else:
if str(os.path.commonpath([os.path.abspath(output_dir), os.path.join(output_dir, filename)])) != str(output_dir):
raise PermissionError("insecure")
file = os.path.join(output_dir, filename)
return file
output_dir = Path(output_dir)
subfolder = Path(subfolder or "")
try:
relative_to = (output_dir / subfolder / filename).relative_to(output_dir)
except ValueError:
raise PermissionError(f"{output_dir / subfolder / filename} is not a subpath of {output_dir}")
return str((output_dir / relative_to).resolve(strict=True))

View File

@ -1,9 +1,15 @@
from __future__ import annotations
import dataclasses
import itertools
import os
import typing
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple
import weakref
from abc import ABC, abstractmethod
from typing import Any, NamedTuple, Optional, Iterable
from pathlib import Path
from .platform_path import construct_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"])
extension_mimetypes_cache = {
@ -11,35 +17,267 @@ extension_mimetypes_cache = {
}
@dataclasses.dataclass
def do_add(collection: list | set, index: int | None, item: Any):
if isinstance(collection, list) and index == 0:
collection.insert(0, item)
elif isinstance(collection, set):
collection.add(item)
else:
assert isinstance(collection, list)
collection.append(item)
class FolderPathsTuple:
folder_name: str
paths: List[str] = dataclasses.field(default_factory=list)
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
def __init__(self, folder_name: str = None, paths: list[str] = None, supported_extensions: set[str] = None, parent: Optional[weakref.ReferenceType[FolderNames]] = None):
paths = paths or []
supported_extensions = supported_extensions or set(supported_pt_extensions)
self.folder_name = folder_name
self.parent = parent
self._paths = paths
self._supported_extensions = supported_extensions
@property
def supported_extensions(self) -> set[str] | SupportedExtensions:
if self.parent is not None and self.folder_name is not None:
return SupportedExtensions(self.folder_name, self.parent)
else:
return self._supported_extensions
@supported_extensions.setter
def supported_extensions(self, value: set[str] | SupportedExtensions):
self.supported_extensions.clear()
self.supported_extensions.update(value)
@property
def paths(self) -> list[str] | PathsList:
if self.parent is not None and self.folder_name is not None:
return PathsList(self.folder_name, self.parent)
else:
return self._paths
def __iter__(self) -> typing.Generator[typing.Iterable[str]]:
"""
allows this proxy to behave like a tuple everywhere it is used by the custom nodes
:return:
"""
yield self.paths
yield self.supported_extensions
def __getitem__(self, item: Any):
if item == 0:
return self.paths
if item == 1:
return self.supported_extensions
else:
raise RuntimeError("unsupported tuple index")
def __add__(self, other: "FolderPathsTuple"):
assert self.folder_name == other.folder_name
# todo: make sure the paths are actually unique, as this method intends
new_paths = list(frozenset(self.paths + other.paths))
new_supported_extensions = self.supported_extensions | other.supported_extensions
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
raise RuntimeError("unsupported tuple index")
def __iter__(self) -> Iterator[Sequence[str]]:
yield self.paths
yield self.supported_extensions
def __iadd__(self, other: FolderPathsTuple):
for path in other.paths:
self.paths.append(path)
for ext in other.supported_extensions:
self.supported_extensions.add(ext)
@dataclasses.dataclass
class PathsList:
folder_name: str
parent: weakref.ReferenceType[FolderNames]
def __iter__(self) -> typing.Generator[str]:
p: FolderNames = self.parent()
for path in p.directory_paths(self.folder_name):
try:
yield str(path.resolve())
except (OSError, AttributeError):
yield str(path)
def __getitem__(self, item: int):
paths = [x for x in self]
return paths[item]
def append(self, path_str: str):
p: FolderNames = self.parent()
p.add_paths(self.folder_name, [path_str])
def insert(self, path_str: str, index: int):
p: FolderNames = self.parent()
p.add_paths(self.folder_name, [path_str], index=index)
@dataclasses.dataclass
class SupportedExtensions:
folder_name: str
parent: weakref.ReferenceType[FolderNames]
def _append_any(self, other):
if other is None:
return
p: FolderNames = self.parent()
if isinstance(other, str):
other = {other}
p.add_supported_extension(self.folder_name, *other)
def __iter__(self) -> typing.Generator[str]:
p: FolderNames = self.parent()
for ext in p.supported_extensions(self.folder_name):
yield ext
def clear(self):
p: FolderNames = self.parent()
p.remove_all_supported_extensions(self.folder_name)
__ior__ = _append_any
add = _append_any
update = _append_any
@dataclasses.dataclass
class ApplicationPaths:
output_directory: Path = dataclasses.field(default_factory=lambda: construct_path("output"))
temp_directory: Path = dataclasses.field(default_factory=lambda: construct_path("temp"))
input_directory: Path = dataclasses.field(default_factory=lambda: construct_path("input"))
user_directory: Path = dataclasses.field(default_factory=lambda: construct_path("user"))
def __iter__(self) -> typing.Generator[Path]:
yield self.output_directory
yield self.temp_directory
yield self.input_directory
yield self.user_directory
@dataclasses.dataclass
class AbstractPaths(ABC):
folder_names: list[str]
supported_extensions: set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
@abstractmethod
def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]:
"""Generate directory paths based on the given base paths."""
pass
@abstractmethod
def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]:
"""Generate file paths based on the given base paths."""
pass
@abstractmethod
def has_folder_name(self, folder_name: str) -> bool:
"""Check if the given folder name is in folder_names."""
pass
@dataclasses.dataclass
class ModelPaths(AbstractPaths):
folder_name_base_path_subdir: Path = dataclasses.field(default_factory=lambda: construct_path("models"))
additional_relative_directory_paths: set[Path] = dataclasses.field(default_factory=set)
additional_absolute_directory_paths: set[str | Path] = dataclasses.field(default_factory=set)
folder_names_are_relative_directory_paths_too: bool = dataclasses.field(default_factory=lambda: True)
def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]:
yielded_so_far: set[Path] = set()
for base_path in base_paths:
if self.folder_names_are_relative_directory_paths_too:
for folder_name in self.folder_names:
resolved_default_folder_name_path = base_path / self.folder_name_base_path_subdir / folder_name
if not resolved_default_folder_name_path in yielded_so_far:
yield resolved_default_folder_name_path
yielded_so_far.add(resolved_default_folder_name_path)
for additional_relative_directory_path in self.additional_relative_directory_paths:
resolved_additional_relative_path = base_path / additional_relative_directory_path
if not resolved_additional_relative_path in yielded_so_far:
yield resolved_additional_relative_path
yielded_so_far.add(resolved_additional_relative_path)
# resolve all paths
yielded_so_far = {path.resolve() for path in yielded_so_far}
for additional_absolute_path in self.additional_absolute_directory_paths:
try:
resolved_absolute_path = additional_absolute_path.resolve()
except (OSError, AttributeError):
resolved_absolute_path = additional_absolute_path
if not resolved_absolute_path in yielded_so_far:
yield resolved_absolute_path
yielded_so_far.add(resolved_absolute_path)
def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]:
for path in self.directory_paths(base_paths):
for dirpath, dirnames, filenames in os.walk(path, followlinks=True):
if '.git' in dirnames:
dirnames.remove('.git')
for filename in filenames:
if any(filename.endswith(ext) for ext in self.supported_extensions):
result_path = construct_path(dirpath) / filename
if relative:
yield result_path.relative_to(path)
else:
yield result_path
def has_folder_name(self, folder_name: str) -> bool:
return folder_name in self.folder_names
@dataclasses.dataclass
class FolderNames:
application_paths: typing.Optional[ApplicationPaths] = dataclasses.field(default_factory=ApplicationPaths)
contents: list[AbstractPaths] = dataclasses.field(default_factory=list)
base_paths: list[Path] = dataclasses.field(default_factory=list)
def supported_extensions(self, folder_name: str) -> typing.Generator[str]:
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
for supported_extensions in candidate.supported_extensions:
for supported_extension in supported_extensions:
yield supported_extension
def directory_paths(self, folder_name: str) -> typing.Generator[Path]:
for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths)
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
yield directory_path
def file_paths(self, folder_name: str, relative=False) -> typing.Generator[Path]:
for file_path in itertools.chain.from_iterable([candidate.file_paths(self.base_paths, relative=relative)
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
yield file_path
def first_existing_or_none(self, folder_name: str, relative_file_path: Path) -> Optional[Path]:
for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths)
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
candidate_file_path: Path = construct_path(directory_path / relative_file_path)
try:
# todo: this should follow the symlink
if Path.exists(candidate_file_path):
return candidate_file_path
except OSError:
continue
return None
def add_supported_extension(self, folder_name: str, *supported_extensions: str | None):
if supported_extensions is None:
return
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
candidate.supported_extensions.update(supported_extensions)
def remove_all_supported_extensions(self, folder_name: str):
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
candidate.supported_extensions.clear()
def add(self, model_paths: AbstractPaths):
self.contents.append(model_paths)
def add_base_path(self, base_path: Path):
if base_path not in self.base_paths:
self.base_paths.append(base_path)
@staticmethod
def from_dict(folder_paths_dict: dict[str, tuple[typing.Sequence[str], Sequence[str]]] = None) -> FolderNames:
def from_dict(folder_paths_dict: dict[str, tuple[typing.Iterable[str], Iterable[str]]] = None) -> FolderNames:
"""
Turns a dictionary of
{
@ -50,65 +288,179 @@ class FolderNames:
:param folder_paths_dict: A dictionary
:return: A FolderNames object
"""
if folder_paths_dict is None:
return FolderNames(os.getcwd())
raise NotImplementedError()
fn = FolderNames(os.getcwd())
for folder_name, (paths, extensions) in folder_paths_dict.items():
paths_tuple = FolderPathsTuple(folder_name=folder_name, paths=list(paths), supported_extensions=set(extensions))
if folder_name in fn:
fn[folder_name] += paths_tuple
else:
fn[folder_name] = paths_tuple
return fn
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
def __getitem__(self, folder_name) -> FolderPathsTuple:
if not isinstance(folder_name, str) or folder_name is None:
raise RuntimeError("expected folder path")
if item not in self.contents:
default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item]
# todo: it is probably never the intention to do this
try:
path = Path(folder_name)
if path.is_absolute():
folder_name = path.stem
except Exception:
pass
if not any(candidate.has_folder_name(folder_name) for candidate in self.contents):
self.add(ModelPaths(folder_names=[folder_name], folder_name_base_path_subdir=construct_path(), supported_extensions=set(), folder_names_are_relative_directory_paths_too=False))
return FolderPathsTuple(folder_name, parent=weakref.ref(self))
def __setitem__(self, key: str, value: FolderPathsTuple):
assert isinstance(key, str)
if isinstance(value, tuple):
def add_paths(self, folder_name: str, paths: list[Path | str], index: Optional[int] = None):
"""
Adds, but does not create, new model paths
:param folder_name:
:param paths:
:param index:
:return:
"""
for candidate in self.contents:
if candidate.has_folder_name(folder_name):
self._modify_model_paths(folder_name, paths, set(), candidate, index=index)
def _modify_model_paths(self, key: str, paths: Iterable[Path | str], supported_extensions: set[str], model_paths: AbstractPaths = None, index: Optional[int] = None) -> AbstractPaths:
model_paths = model_paths or ModelPaths([key],
supported_extensions=set(supported_extensions),
folder_names_are_relative_directory_paths_too=False)
if index is not None and index != 0:
raise ValueError(f"index was {index} but only 0 or None is supported")
for path in paths:
if isinstance(path, str):
path = construct_path(path)
# when given absolute paths, try to formulate them as relative paths anyway
if path.is_absolute():
for base_path in self.base_paths:
try:
relative_to_basepath = path.relative_to(base_path)
potential_folder_name = relative_to_basepath.stem
potential_subdir = relative_to_basepath.parent
# does the folder_name so far match the key?
# or have we not seen this folder before?
folder_name_not_already_in_contents = all(not candidate.has_folder_name(potential_folder_name) for candidate in self.contents)
if potential_folder_name == key or folder_name_not_already_in_contents:
# fix the subdir
model_paths.folder_name_base_path_subdir = potential_subdir
model_paths.folder_names_are_relative_directory_paths_too = True
if folder_name_not_already_in_contents:
do_add(model_paths.folder_names, index, potential_folder_name)
else:
# if the folder name doesn't match the key, check if we have ever seen a folder name that matches the key:
if model_paths.folder_names_are_relative_directory_paths_too:
if potential_subdir == model_paths.folder_name_base_path_subdir:
# then we want to add this to the folder name, because it's probably compatibility
do_add(model_paths.folder_names, index, potential_folder_name)
else:
# not this case
model_paths.folder_names_are_relative_directory_paths_too = False
if not model_paths.folder_names_are_relative_directory_paths_too:
model_paths.additional_relative_directory_paths.add(relative_to_basepath)
for resolve_folder_name in model_paths.folder_names:
model_paths.additional_relative_directory_paths.add(model_paths.folder_name_base_path_subdir / resolve_folder_name)
# since this was an absolute path that was a subdirectory of one of the base paths,
# we are done
break
except ValueError:
# this is not a subpath of the base path
pass
# if we got this far, none of the absolute paths were subdirectories of any base paths
# add it to our absolute paths
model_paths.additional_absolute_directory_paths.add(path)
else:
# since this is a relative path, peacefully add it to model_paths
potential_folder_name = path.stem
try:
relative_to_current_subdir = path.relative_to(model_paths.folder_name_base_path_subdir)
# if relative to the current subdir, we observe only one part, we're good to go
if len(relative_to_current_subdir.parts) == 1:
if potential_folder_name == key:
model_paths.folder_names_are_relative_directory_paths_too = True
else:
# if there already exists a folder_name by this name, do not add it, and switch to all relative paths
if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents):
model_paths.folder_names_are_relative_directory_paths_too = False
model_paths.additional_relative_directory_paths.add(path)
model_paths.folder_name_base_path_subdir = construct_path()
else:
do_add(model_paths.folder_names, index, potential_folder_name)
except ValueError:
# this means that the relative directory didn't contain the subdir so far
# something_not_models/key
if potential_folder_name == key:
model_paths.folder_name_base_path_subdir = path.parent
model_paths.folder_names_are_relative_directory_paths_too = True
else:
if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents):
model_paths.folder_names_are_relative_directory_paths_too = False
model_paths.additional_relative_directory_paths.add(path)
model_paths.folder_name_base_path_subdir = construct_path()
else:
do_add(model_paths.folder_names, index, potential_folder_name)
return model_paths
def __setitem__(self, key: str, value: tuple | FolderPathsTuple | AbstractPaths):
# remove all existing paths for this key
self.__delitem__(key)
if isinstance(value, AbstractPaths):
self.contents.append(value)
elif isinstance(value, (tuple, FolderPathsTuple)):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
self.contents[key] = value
# typical cases:
# key="checkpoints", paths="C:/base_path/models/checkpoints"
# key="unets", paths="C:/base_path/models/unets", "C:/base_path/models/diffusion_models"
# ^ in this case, we will want folder_names to be both unets and diffusion_models
# key="custom_loader", paths="C:/base_path/models/checkpoints"
# if the paths are subdirectories of any basepath, use relative paths
paths: list[Path] = list(map(Path, paths))
self.contents.append(self._modify_model_paths(key, paths, supported_extensions))
def __len__(self):
return len(self.contents)
def __iter__(self):
return iter(self.contents)
def __iter__(self) -> typing.Generator[str]:
for model_paths in self.contents:
for folder_name in model_paths.folder_names:
yield folder_name
def __delitem__(self, key):
del self.contents[key]
to_remove: list[AbstractPaths] = []
for model_paths in self.contents:
if model_paths.has_folder_name(key):
to_remove.append(model_paths)
def __contains__(self, item):
return item in self.contents
for model_paths in to_remove:
self.contents.remove(model_paths)
def __contains__(self, item: str):
return any(candidate.has_folder_name(item) for candidate in self.contents)
def items(self):
return self.contents.items()
items_view = {
folder_name: self[folder_name] for folder_name in self.keys()
}
return items_view.items()
def values(self):
return self.contents.values()
return [self[folder_name] for folder_name in self.keys()]
def keys(self):
return self.contents.keys()
return [x for x in self]
def get(self, key, __default=None):
return self.contents.get(key, __default)
for candidate in self.contents:
if candidate.has_folder_name(key):
return FolderPathsTuple(key, parent=weakref.ref(self))
if __default is not None:
raise ValueError("get with default is not supported")
class SaveImagePathResponse(NamedTuple):
class SaveImagePathTuple(NamedTuple):
full_output_folder: str
filename: str
counter: int

View File

@ -1,22 +1,44 @@
import sys
from functools import wraps
def create_module_properties():
properties = {}
patched_modules = set()
def module_property(func):
"""Decorator to turn module functions into properties.
Function names must be prefixed with an underscore."""
module = sys.modules[func.__module__]
def patch_module(module):
if module in patched_modules:
return
def base_getattr(name):
raise AttributeError(
f"module '{module.__name__}' has no attribute '{name}'")
def base_getattr(name):
raise AttributeError(f"module '{module.__name__}' has no attribute '{name}'")
old_getattr = getattr(module, '__getattr__', base_getattr)
old_getattr = getattr(module, '__getattr__', base_getattr)
def new_getattr(name):
if f'_{name}' == func.__name__:
return func()
else:
return old_getattr(name)
def new_getattr(name):
if name in properties:
return properties[name]()
else:
return old_getattr(name)
module.__getattr__ = new_getattr
return func
module.__getattr__ = new_getattr
patched_modules.add(module)
class ModuleProperties:
@staticmethod
def getter(func):
@wraps(func)
def wrapper():
return func()
name = func.__name__
if name.startswith('_'):
properties[name[1:]] = wrapper
else:
raise ValueError("Property function names must start with an underscore")
module = sys.modules[func.__module__]
patch_module(module)
return wrapper
return ModuleProperties()

View File

@ -0,0 +1,10 @@
from __future__ import annotations
import typing
# for nodes that return:
# { "ui" : { "some_field": Any, "other": Any }}
# the outputs dict will be
# (the node id)
# { "1": { "some_field": Any, "other": Any }}
OutputsDict = dict[str, dict[str, typing.Any]]

View File

@ -0,0 +1,13 @@
from __future__ import annotations
from pathlib import PurePosixPath, Path, PosixPath
def construct_path(*args) -> PurePosixPath | Path:
if len(args) > 0 and args[0] is not None and isinstance(args[0], str) and args[0].startswith("/"):
try:
return PosixPath(*args)
except Exception:
return PurePosixPath(*args)
else:
return Path(*args)

View File

@ -8,13 +8,15 @@ from typing import Tuple
from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
QueueTuple = Tuple[float, str, dict, dict, list]
MAXIMUM_HISTORY_SIZE = 10000
class TaskInvocation(NamedTuple):
item_id: int | str
outputs: dict
outputs: OutputsDict
status: Optional[ExecutionStatus]

View File

@ -10,11 +10,12 @@ from aio_pika.patterns import JsonRPC
from aiohttp import web
from aiormq import AMQPConnectionError
from .executors import ContextVarExecutor
from .distributed_progress import DistributedExecutorToClientProgress
from .distributed_types import RpcRequest, RpcReply
from .process_pool_executor import ProcessPoolExecutor
from ..client.embedded_comfy_client import EmbeddedComfyClient
from ..cmd.main_pre import tracer
from ..component_model.executor_types import Executor
from ..component_model.queue_types import ExecutionStatus
@ -28,7 +29,7 @@ class DistributedPromptWorker:
queue_name: str = "comfyui",
health_check_port: int = 9090,
loop: Optional[AbstractEventLoop] = None,
executor: Optional[Executor] = None):
executor: Optional[ContextVarExecutor | ProcessPoolExecutor] = None):
self._rpc = None
self._channel = None
self._exit_stack = AsyncExitStack()

View File

@ -0,0 +1,24 @@
import concurrent
import contextvars
import typing
from concurrent.futures import Future, ThreadPoolExecutor
from functools import partial
__version__ = '0.0.1'
from .process_pool_executor import ProcessPoolExecutor
class ContextVarExecutor(ThreadPoolExecutor):
def submit(self, fn: typing.Callable, *args, **kwargs) -> Future:
ctx = contextvars.copy_context() # type: contextvars.Context
return super().submit(partial(ctx.run, partial(fn, *args, **kwargs)))
class ContextVarProcessPoolExecutor(ProcessPoolExecutor):
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
# TODO: serialize the "comfyui_execution_context"
pass

View File

@ -6,6 +6,7 @@ from dataclasses import dataclass, replace
from typing import Optional, Final
from .component_model.executor_types import ExecutorToClientProgress
from .component_model.folder_path_types import FolderNames
from .distributed.server_stub import ServerStub
_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
@ -14,23 +15,21 @@ _current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
@dataclass(frozen=True)
class ExecutionContext:
server: ExecutorToClientProgress
folder_names_and_paths: FolderNames
node_id: Optional[str] = None
task_id: Optional[str] = None
inference_mode: bool = True
_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())
_current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames()))
def current_execution_context() -> ExecutionContext:
try:
return _current_context.get()
except LookupError:
return _empty_execution_context
return _current_context.get()
@contextmanager
def new_execution_context(ctx: ExecutionContext):
def _new_execution_context(ctx: ExecutionContext):
token = _current_context.set(ctx)
try:
yield ctx
@ -39,8 +38,24 @@ def new_execution_context(ctx: ExecutionContext):
@contextmanager
def context_execute_node(node_id: str, prompt_id: str):
def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
with new_execution_context(new_ctx):
new_ctx = replace(current_ctx, folder_names_and_paths=folder_names_and_paths)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_execute_node(node_id: str):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, node_id=node_id)
with _new_execution_context(new_ctx):
yield new_ctx

View File

@ -198,10 +198,6 @@ Visit the repository, accept the terms, and then do one of the following:
- Login to Hugging Face in your terminal using `huggingface-cli login`
""")
raise exc_info
finally:
# a path was found for any reason, so we should invalidate the cache
if path is not None:
folder_paths.invalidate_cache(folder_name)
if path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
return path
@ -504,7 +500,6 @@ def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]
pre_existing = frozenset(known_models)
known_models.extend([model for model in models if model not in pre_existing])
folder_paths.invalidate_cache(folder_name)
return known_models

View File

@ -2,7 +2,6 @@ import torch
import os
import json
import hashlib
import math
import random
import logging
@ -1654,7 +1653,7 @@ class LoadImageMask:
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {"required":
{"image": (sorted(files), {"image_upload": True}),
{"image": (natsorted(files), {"image_upload": True}),
"channel": (s._color_channels, ), }
}

View File

@ -29,9 +29,9 @@ import sys
import warnings
from contextlib import contextmanager
from pathlib import Path
from pickle import UnpicklingError
from typing import Optional, Any
import accelerate
import numpy as np
import safetensors.torch
import torch
@ -65,7 +65,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
if ckpt is None:
raise FileNotFoundError("the checkpoint was not found")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
elif ckpt.lower().endswith("index.json"):
# from accelerate
index_filename = ckpt
@ -81,20 +81,29 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
for checkpoint_file in checkpoint_files:
sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
try:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
except UnpicklingError as exc_info:
try:
# wrong extension is most likely, try to load as safetensors anyway
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
return sd
except Exception:
exc_info.add_note(f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again")
raise exc_info
return sd

View File

@ -13,7 +13,7 @@ from transformers.models.nllb.tokenization_nllb import \
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import SaveImagePathResponse
from comfy.component_model.folder_path_types import SaveImagePathTuple
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
TOKENS_TYPE_NAME, LanguageModel
@ -397,7 +397,7 @@ class SaveString(CustomNode):
OUTPUT_NODE = True
RETURN_TYPES = ()
def get_save_path(self, filename_prefix) -> SaveImagePathResponse:
def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):

View File

@ -1,21 +0,0 @@
import torch
class TorchCompileModel:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
return (m, )
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
}

10
main.py
View File

@ -2,13 +2,15 @@ import asyncio
import warnings
from pathlib import Path
if __name__ == "__main__":
from comfy.cmd.folder_paths_pre import set_base_path
from comfy.component_model.folder_path_types import FolderNames
if __name__ == "__main__":
warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning)
this_file_parent_dir = Path(__file__).parent
set_base_path(str(this_file_parent_dir))
from comfy.cmd.main import main
from comfy.cmd.folder_paths import folder_names_and_paths # type: FolderNames
fn: FolderNames = folder_names_and_paths
fn.base_paths.clear()
fn.base_paths.append(this_file_parent_dir)
asyncio.run(main(from_script_dir=this_file_parent_dir))

View File

@ -1,6 +1,8 @@
import asyncio
import logging
logging.basicConfig(level=logging.ERROR)
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Callable
import jwt
@ -15,10 +17,12 @@ from comfy.component_model.executor_types import Executor
from comfy.component_model.make_mutable import make_mutable
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
from comfy.distributed.executors import ContextVarExecutor
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub
def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt
@ -37,8 +41,11 @@ async def test_sign_jwt_auth_none():
assert user_token["sub"] == client_id
_executor_factories: tuple[Executor] = (ContextVarExecutor,)
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
@pytest.mark.parametrize("executor_factory", _executor_factories)
async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
@ -72,7 +79,7 @@ async def test_distributed_prompt_queues_same_process():
frontend.put(test_prompt)
# start a worker thread
thread_pool = ThreadPoolExecutor(max_workers=1)
thread_pool = ContextVarExecutor(max_workers=1)
async def in_thread():
incoming, incoming_prompt_id = worker.get()
@ -127,7 +134,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
@pytest.mark.asyncio
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
@pytest.mark.parametrize("executor_factory", _executor_factories)
async def test_basic_queue_worker_with_health_check(executor_factory):
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()

View File

@ -1,3 +1,4 @@
import logging
import uuid
from contextvars import ContextVar
from typing import Dict, Optional
@ -73,7 +74,8 @@ class ComfyClient:
prompt_id = str(uuid.uuid4())
try:
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
except (RuntimeError, DependencyCycleError):
except (RuntimeError, DependencyCycleError) as exc_info:
logging.warning("error when queueing prompt", exc_info=exc_info)
outputs = {}
result = RunResult(prompt_id=prompt_id)
result.outputs = outputs

View File

@ -2,11 +2,13 @@
# TODO(yoland): clean up this after I get back down
import os
import tempfile
from unittest.mock import patch
from pathlib import Path
import pytest
from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import FolderNames, ModelPaths
from comfy.execution_context import context_folder_names_and_paths
@pytest.fixture
@ -40,16 +42,6 @@ def test_add_model_folder_path():
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
def test_recursive_search(temp_dir):
os.makedirs(os.path.join(temp_dir, "subdir"))
open(os.path.join(temp_dir, "file1.txt"), "w").close()
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
files, dirs = folder_paths.recursive_search(temp_dir)
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
assert len(dirs) == 2 # temp_dir and subdir
def test_filter_files_extensions():
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
@ -57,16 +49,24 @@ def test_filter_files_extensions():
assert folder_paths.filter_files_extensions(files, []) == files
@patch("folder_paths.recursive_search")
@patch("folder_paths.folder_names_and_paths")
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_filename_list(temp_dir):
base_path = Path(temp_dir)
fn = FolderNames(base_paths=[base_path])
rel_path = Path("test/path")
fn.add(ModelPaths(["test_folder"], additional_relative_directory_paths={rel_path}, supported_extensions={".txt"}))
dir_path = base_path / rel_path
Path.mkdir(dir_path, parents=True, exist_ok=True)
files = ["file1.txt", "file2.jpg"]
for file in files:
Path.touch(dir_path / file, exist_ok=True)
with context_folder_names_and_paths(fn):
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
def test_get_save_image_path(temp_dir):
with patch("folder_paths.output_directory", temp_dir):
with context_folder_names_and_paths(FolderNames(base_paths=[Path(temp_dir)])):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
assert os.path.samefile(full_output_folder, temp_dir)
assert filename == "test"