mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
--base-paths argument adds additional paths to search for models/checkpoints, models/loras, etc. directories, including directories specified in this pattern by custom nodes
This commit is contained in:
parent
89d07f3adf
commit
4a13766d14
101
README.md
101
README.md
@ -722,19 +722,31 @@ You can pass additional extra model path configurations with one or more copies
|
||||
### Command Line Arguments
|
||||
|
||||
```
|
||||
usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [-H [IP]] [--port PORT] [--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE]
|
||||
[--extra-model-paths-config PATH [PATH ...]] [--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch]
|
||||
[--disable-auto-launch] [--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16]
|
||||
usage: comfyui.exe [-h] [-c CONFIG_FILE] [--write-out-config-file CONFIG_OUTPUT_PATH] [-w CWD] [--base-paths BASE_PATHS [BASE_PATHS ...]] [-H [IP]] [--port PORT]
|
||||
[--enable-cors-header [ORIGIN]] [--max-upload-size MAX_UPLOAD_SIZE] [--extra-model-paths-config PATH [PATH ...]]
|
||||
[--output-directory OUTPUT_DIRECTORY] [--temp-directory TEMP_DIRECTORY] [--input-directory INPUT_DIRECTORY] [--auto-launch] [--disable-auto-launch]
|
||||
[--cuda-device DEVICE_ID] [--cuda-malloc | --disable-cuda-malloc] [--force-fp32 | --force-fp16 | --force-bf16]
|
||||
[--bf16-unet | --fp16-unet | --fp8_e4m3fn-unet | --fp8_e5m2-unet] [--fp16-vae | --fp32-vae | --bf16-vae] [--cpu-vae]
|
||||
[--fp8_e4m3fn-text-enc | --fp8_e5m2-text-enc | --fp16-text-enc | --fp32-text-enc] [--directml [DIRECTML_DEVICE]] [--disable-ipex-optimize]
|
||||
[--preview-method [none,auto,latent2rgb,taesd]] [--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers]
|
||||
[--force-upcast-attention | --dont-upcast-attention] [--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--disable-smart-memory] [--deterministic]
|
||||
[--dont-print-server] [--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--multi-user] [--create-directories]
|
||||
[--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN] [--analytics-use-identity-provider]
|
||||
[--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker] [--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME]
|
||||
[--external-address EXTERNAL_ADDRESS] [--verbose] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE] [--otel-service-name OTEL_SERVICE_NAME]
|
||||
[--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT]
|
||||
|
||||
[--preview-method [none,auto,latent2rgb,taesd]] [--preview-size PREVIEW_SIZE] [--cache-lru CACHE_LRU]
|
||||
[--use-split-cross-attention | --use-quad-cross-attention | --use-pytorch-cross-attention] [--disable-xformers] [--disable-flash-attn]
|
||||
[--disable-sage-attention] [--force-upcast-attention | --dont-upcast-attention]
|
||||
[--gpu-only | --highvram | --normalvram | --lowvram | --novram | --cpu] [--reserve-vram RESERVE_VRAM]
|
||||
[--default-hashing-function {md5,sha1,sha256,sha512}] [--disable-smart-memory] [--deterministic] [--fast] [--dont-print-server]
|
||||
[--quick-test-for-ci] [--windows-standalone-build] [--disable-metadata] [--disable-all-custom-nodes] [--multi-user] [--create-directories]
|
||||
[--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL] [--plausible-analytics-domain PLAUSIBLE_ANALYTICS_DOMAIN]
|
||||
[--analytics-use-identity-provider] [--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI] [--distributed-queue-worker]
|
||||
[--distributed-queue-frontend] [--distributed-queue-name DISTRIBUTED_QUEUE_NAME] [--external-address EXTERNAL_ADDRESS]
|
||||
[--logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}] [--disable-known-models] [--max-queue-size MAX_QUEUE_SIZE]
|
||||
[--otel-service-name OTEL_SERVICE_NAME] [--otel-service-version OTEL_SERVICE_VERSION] [--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT]
|
||||
[--force-channels-last] [--force-hf-local-dir-mode] [--front-end-version FRONT_END_VERSION] [--front-end-root FRONT_END_ROOT]
|
||||
[--executor-factory EXECUTOR_FACTORY] [--openai-api-key OPENAI_API_KEY] [--user-directory USER_DIRECTORY] [--blip-model-url BLIP_MODEL_URL]
|
||||
[--blip-model-vqa-url BLIP_MODEL_VQA_URL] [--sam-model-vith-url SAM_MODEL_VITH_URL] [--sam-model-vitl-url SAM_MODEL_VITL_URL]
|
||||
[--sam-model-vitb-url SAM_MODEL_VITB_URL] [--history-display-limit HISTORY_DISPLAY_LIMIT] [--ffmpeg-bin-path FFMPEG_BIN_PATH]
|
||||
[--ffmpeg-extra-codecs FFMPEG_EXTRA_CODECS] [--wildcards-path WILDCARDS_PATH] [--wildcard-api WILDCARD_API] [--photoprism-host PHOTOPRISM_HOST]
|
||||
[--immich-host IMMICH_HOST] [--ideogram-session-cookie IDEOGRAM_SESSION_COOKIE] [--annotator-ckpts-path ANNOTATOR_CKPTS_PATH] [--use-symlinks]
|
||||
[--ort-providers ORT_PROVIDERS] [--vfi-ops-backend VFI_OPS_BACKEND] [--dependency-version DEPENDENCY_VERSION] [--mmdet-skip] [--sam-editor-cpu]
|
||||
[--sam-editor-model SAM_EDITOR_MODEL] [--custom-wildcards CUSTOM_WILDCARDS] [--disable-gpu-opencv]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
@ -742,10 +754,14 @@ options:
|
||||
config file path
|
||||
--write-out-config-file CONFIG_OUTPUT_PATH
|
||||
takes the current command line args and writes them out to a config file at the given path, then exits
|
||||
-w CWD, --cwd CWD Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default. [env var:
|
||||
COMFYUI_CWD]
|
||||
-w CWD, --cwd CWD Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be
|
||||
located here by default. [env var: COMFYUI_CWD]
|
||||
--base-paths BASE_PATHS [BASE_PATHS ...]
|
||||
Additional base paths for custom nodes, models and inputs. [env var: COMFYUI_BASE_PATHS]
|
||||
-H [IP], --listen [IP]
|
||||
Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all) [env var: COMFYUI_LISTEN]
|
||||
Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like:
|
||||
127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6) [env var:
|
||||
COMFYUI_LISTEN]
|
||||
--port PORT Set the listen port. [env var: COMFYUI_PORT]
|
||||
--enable-cors-header [ORIGIN]
|
||||
Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'. [env var: COMFYUI_ENABLE_CORS_HEADER]
|
||||
@ -789,6 +805,10 @@ options:
|
||||
Disables ipex.optimize when loading models with Intel GPUs. [env var: COMFYUI_DISABLE_IPEX_OPTIMIZE]
|
||||
--preview-method [none,auto,latent2rgb,taesd]
|
||||
Default preview method for sampler nodes. [env var: COMFYUI_PREVIEW_METHOD]
|
||||
--preview-size PREVIEW_SIZE
|
||||
Sets the maximum preview size for sampler nodes. [env var: COMFYUI_PREVIEW_SIZE]
|
||||
--cache-lru CACHE_LRU
|
||||
Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM. [env var: COMFYUI_CACHE_LRU]
|
||||
--use-split-cross-attention
|
||||
Use the split cross attention optimization. Ignored when xformers is used. [env var: COMFYUI_USE_SPLIT_CROSS_ATTENTION]
|
||||
--use-quad-cross-attention
|
||||
@ -796,6 +816,9 @@ options:
|
||||
--use-pytorch-cross-attention
|
||||
Use the new pytorch 2.0 cross attention function. [env var: COMFYUI_USE_PYTORCH_CROSS_ATTENTION]
|
||||
--disable-xformers Disable xformers. [env var: COMFYUI_DISABLE_XFORMERS]
|
||||
--disable-flash-attn Disable Flash Attention [env var: COMFYUI_DISABLE_FLASH_ATTN]
|
||||
--disable-sage-attention
|
||||
Disable Sage Attention [env var: COMFYUI_DISABLE_SAGE_ATTENTION]
|
||||
--force-upcast-attention
|
||||
Force enable attention upcasting, please report if it fixes black images. [env var: COMFYUI_FORCE_UPCAST_ATTENTION]
|
||||
--dont-upcast-attention
|
||||
@ -806,15 +829,25 @@ options:
|
||||
--lowvram Split the unet in parts to use less vram. [env var: COMFYUI_LOWVRAM]
|
||||
--novram When lowvram isn't enough. [env var: COMFYUI_NOVRAM]
|
||||
--cpu To use the CPU for everything (slow). [env var: COMFYUI_CPU]
|
||||
--reserve-vram RESERVE_VRAM
|
||||
Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.
|
||||
[env var: COMFYUI_RESERVE_VRAM]
|
||||
--default-hashing-function {md5,sha1,sha256,sha512}
|
||||
Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256. [env var:
|
||||
COMFYUI_DEFAULT_HASHING_FUNCTION]
|
||||
--disable-smart-memory
|
||||
Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can. [env var: COMFYUI_DISABLE_SMART_MEMORY]
|
||||
--deterministic Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var: COMFYUI_DETERMINISTIC]
|
||||
--deterministic Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases. [env var:
|
||||
COMFYUI_DETERMINISTIC]
|
||||
--fast Enable some untested and potentially quality deteriorating optimizations. [env var: COMFYUI_FAST]
|
||||
--dont-print-server Don't print server output. [env var: COMFYUI_DONT_PRINT_SERVER]
|
||||
--quick-test-for-ci Quick test for CI. Raises an error if nodes cannot be imported, [env var: COMFYUI_QUICK_TEST_FOR_CI]
|
||||
--windows-standalone-build
|
||||
Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup). [env var:
|
||||
COMFYUI_WINDOWS_STANDALONE_BUILD]
|
||||
Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening
|
||||
the page on startup). [env var: COMFYUI_WINDOWS_STANDALONE_BUILD]
|
||||
--disable-metadata Disable saving prompt metadata in files. [env var: COMFYUI_DISABLE_METADATA]
|
||||
--disable-all-custom-nodes
|
||||
Disable loading all custom nodes. [env var: COMFYUI_DISABLE_ALL_CUSTOM_NODES]
|
||||
--multi-user Enables per-user storage. [env var: COMFYUI_MULTI_USER]
|
||||
--create-directories Creates the default models/, input/, output/ and temp/ directories, then exits. [env var: COMFYUI_CREATE_DIRECTORIES]
|
||||
--plausible-analytics-base-url PLAUSIBLE_ANALYTICS_BASE_URL
|
||||
@ -824,18 +857,19 @@ options:
|
||||
--analytics-use-identity-provider
|
||||
Uses platform identifiers for unique visitor analytics. [env var: COMFYUI_ANALYTICS_USE_IDENTITY_PROVIDER]
|
||||
--distributed-queue-connection-uri DISTRIBUTED_QUEUE_CONNECTION_URI
|
||||
EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt execution requests and progress
|
||||
updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI]
|
||||
EXAMPLE: "amqp://guest:guest@127.0.0.1" - Servers and clients will connect to this AMPQ URL to form a distributed queue and exchange prompt
|
||||
execution requests and progress updates. [env var: COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI]
|
||||
--distributed-queue-worker
|
||||
Workers will pull requests off the AMQP URL. [env var: COMFYUI_DISTRIBUTED_QUEUE_WORKER]
|
||||
--distributed-queue-frontend
|
||||
Frontends will start the web UI and connect to the provided AMQP URL to submit prompts. [env var: COMFYUI_DISTRIBUTED_QUEUE_FRONTEND]
|
||||
--distributed-queue-name DISTRIBUTED_QUEUE_NAME
|
||||
This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the
|
||||
user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME]
|
||||
This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue
|
||||
name, followed by a '.', then the user ID [env var: COMFYUI_DISTRIBUTED_QUEUE_NAME]
|
||||
--external-address EXTERNAL_ADDRESS
|
||||
Specifies a base URL for external addresses reported by the API, such as for image paths. [env var: COMFYUI_EXTERNAL_ADDRESS]
|
||||
--verbose Enables more debug prints. [env var: COMFYUI_VERBOSE]
|
||||
--logging-level {DEBUG,INFO,WARNING,ERROR,CRITICAL}
|
||||
Set the logging level [env var: COMFYUI_LOGGING_LEVEL]
|
||||
--disable-known-models
|
||||
Disables automatic downloads of known models and prevents them from appearing in the UI. [env var: COMFYUI_DISABLE_KNOWN_MODELS]
|
||||
--max-queue-size MAX_QUEUE_SIZE
|
||||
@ -845,8 +879,27 @@ options:
|
||||
--otel-service-version OTEL_SERVICE_VERSION
|
||||
The version of the service or application that is generating telemetry data. [env var: OTEL_SERVICE_VERSION]
|
||||
--otel-exporter-otlp-endpoint OTEL_EXPORTER_OTLP_ENDPOINT
|
||||
A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one
|
||||
environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT]
|
||||
A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the
|
||||
same endpoint and want one environment variable to control the endpoint. [env var: OTEL_EXPORTER_OTLP_ENDPOINT]
|
||||
--force-channels-last
|
||||
Force channels last format when inferencing the models. [env var: COMFYUI_FORCE_CHANNELS_LAST]
|
||||
--force-hf-local-dir-mode
|
||||
Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with
|
||||
the "cache_dir" argument, recreating the traditional file structure. [env var: COMFYUI_FORCE_HF_LOCAL_DIR_MODE]
|
||||
--front-end-version FRONT_END_VERSION
|
||||
Specifies the version of the frontend to be used. This command needs internet connectivity to query and download available frontend
|
||||
implementations from GitHub releases. The version string should be in the format of: [repoOwner]/[repoName]@[version] where version is one of:
|
||||
"latest" or a valid version number (e.g. "1.0.0") [env var: COMFYUI_FRONT_END_VERSION]
|
||||
--front-end-root FRONT_END_ROOT
|
||||
The local filesystem path to the directory where the frontend is located. Overrides --front-end-version. [env var: COMFYUI_FRONT_END_ROOT]
|
||||
--executor-factory EXECUTOR_FACTORY
|
||||
When running ComfyUI as a distributed worker, this specifies the kind of executor that should be used to run the actual ComfyUI workflow
|
||||
worker. A ThreadPoolExecutor is the default. A ProcessPoolExecutor results in better memory management, since the process will be closed and
|
||||
large, contiguous blocks of CUDA memory can be freed. [env var: COMFYUI_EXECUTOR_FACTORY]
|
||||
--openai-api-key OPENAI_API_KEY
|
||||
Configures the OpenAI API Key for the OpenAI nodes [env var: OPENAI_API_KEY]
|
||||
--user-directory USER_DIRECTORY
|
||||
Set the ComfyUI user directory with an absolute path. [env var: COMFYUI_USER_DIRECTORY]
|
||||
|
||||
Args that start with '--' can also be set in a config file (config.yaml or config.json or specified via -c). Config file syntax allows: key=value, flag=true, stuff=[a,b,c] (for details, see syntax at
|
||||
https://goo.gl/R74nmi). In general, command-line values override environment variables which override config file values which override defaults.
|
||||
|
||||
@ -28,6 +28,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
|
||||
parser.add_argument('-w', "--cwd", type=str, default=None,
|
||||
help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
|
||||
parser.add_argument("--base-paths", type=str, nargs='+', default=[], help="Additional base paths for custom nodes, models and inputs.")
|
||||
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
|
||||
help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
@ -161,7 +162,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
help="This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID")
|
||||
parser.add_argument("--external-address", required=False,
|
||||
help="Specifies a base URL for external addresses reported by the API, such as for image paths.")
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--logging-level", type=str, default='INFO', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--disable-known-models", action="store_true", help="Disables automatic downloads of known models and prevents them from appearing in the UI.")
|
||||
parser.add_argument("--max-queue-size", type=int, default=65536, help="The API will reject prompt requests if the queue's size exceeds this value.")
|
||||
# tracing
|
||||
@ -251,11 +252,6 @@ def _parse_args(parser: Optional[argparse.ArgumentParser] = None, args_parsing:
|
||||
if args.disable_auto_launch:
|
||||
args.auto_launch = False
|
||||
|
||||
logging_level = logging.INFO
|
||||
if args.verbose:
|
||||
logging_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
configuration_obj = Configuration(**vars(args))
|
||||
configuration_obj.config_files = config_files
|
||||
assert all(isinstance(config_file, str) for config_file in config_files)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Callable, Any, Union, Mapping, NamedTuple
|
||||
|
||||
import configargparse
|
||||
@ -36,15 +38,16 @@ class Configuration(dict):
|
||||
|
||||
Attributes:
|
||||
config_files (Optional[List[str]]): Path to the configuration file(s) that were set in the arguments.
|
||||
cwd (Optional[str]): Working directory. Defaults to the current directory.
|
||||
cwd (Optional[str]): Working directory. Defaults to the current directory. This is always treated as a base path for model files, and it will be the place where model files are downloaded.
|
||||
base_paths (Optional[list[str]]): Additional base paths for custom nodes, models and inputs.
|
||||
listen (str): IP address to listen on. Defaults to "127.0.0.1".
|
||||
port (int): Port number for the server to listen on. Defaults to 8188.
|
||||
enable_cors_header (Optional[str]): Enables CORS with the specified origin.
|
||||
max_upload_size (float): Maximum upload size in MB. Defaults to 100.
|
||||
extra_model_paths_config (Optional[List[str]]): Extra model paths configuration files.
|
||||
output_directory (Optional[str]): Directory for output files.
|
||||
output_directory (Optional[str]): Directory for output files. This can also be a relative path to the cwd or current working directory.
|
||||
temp_directory (Optional[str]): Temporary directory for processing.
|
||||
input_directory (Optional[str]): Directory for input files.
|
||||
input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths.
|
||||
auto_launch (bool): Auto-launch UI in the default browser. Defaults to False.
|
||||
disable_auto_launch (bool): Disable auto-launching the browser.
|
||||
cuda_device (Optional[int]): CUDA device ID. None means default device.
|
||||
@ -87,7 +90,6 @@ class Configuration(dict):
|
||||
reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS
|
||||
disable_smart_memory (bool): Disable smart memory management.
|
||||
deterministic (bool): Use deterministic algorithms where possible.
|
||||
dont_print_server (bool): Suppress server output.
|
||||
quick_test_for_ci (bool): Enable quick testing mode for CI.
|
||||
windows_standalone_build (bool): Enable features for standalone Windows build.
|
||||
disable_metadata (bool): Disable saving metadata with outputs.
|
||||
@ -103,7 +105,7 @@ class Configuration(dict):
|
||||
distributed_queue_worker (bool): Workers will pull requests off the AMQP URL.
|
||||
distributed_queue_name (str): This name will be used by the frontends and workers to exchange prompt requests and replies. Progress updates will be prefixed by the queue name, followed by a '.', then the user ID.
|
||||
external_address (str): Specifies a base URL for external addresses reported by the API, such as for image paths.
|
||||
verbose (bool | str): Shows extra output for debugging purposes such as import errors of custom nodes; or, specifies a log level
|
||||
logging_level (str): Specifies a log level
|
||||
disable_known_models (bool): Disables automatic downloads of known models and prevents them from appearing in the UI.
|
||||
max_queue_size (int): The API will reject prompt requests if the queue's size exceeds this value.
|
||||
otel_service_name (str): The name of the service or application that is generating telemetry data. Default: "comfyui".
|
||||
@ -122,6 +124,7 @@ class Configuration(dict):
|
||||
self._observers: List[ConfigObserver] = []
|
||||
self.config_files = []
|
||||
self.cwd: Optional[str] = None
|
||||
self.base_paths: list[Path] = []
|
||||
self.listen: str = "127.0.0.1"
|
||||
self.port: int = 8188
|
||||
self.enable_cors_header: Optional[str] = None
|
||||
@ -192,7 +195,7 @@ class Configuration(dict):
|
||||
self.force_channels_last: bool = False
|
||||
self.force_hf_local_dir_mode = False
|
||||
self.preview_size: int = 512
|
||||
self.verbose: str | bool = "INFO"
|
||||
self.logging_level: str = "INFO"
|
||||
|
||||
# from guill
|
||||
self.cache_lru: int = 0
|
||||
@ -253,6 +256,17 @@ class Configuration(dict):
|
||||
self.update(state)
|
||||
self._observers = []
|
||||
|
||||
@property
|
||||
def verbose(self) -> str:
|
||||
return self.logging_level
|
||||
|
||||
@verbose.setter
|
||||
def verbose(self, value):
|
||||
if isinstance(value, bool):
|
||||
self.logging_level = "DEBUG"
|
||||
else:
|
||||
self.logging_level = value
|
||||
|
||||
|
||||
class EnumAction(argparse.Action):
|
||||
"""
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from asyncio import get_event_loop
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from multiprocessing import RLock
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import context, propagate
|
||||
@ -17,13 +18,16 @@ from opentelemetry.trace import Status, StatusCode
|
||||
from .client_types import V1QueuePromptResponse
|
||||
from ..api.components.schema.prompt import PromptDict
|
||||
from ..cli_args_types import Configuration
|
||||
from ..cmd.folder_paths import init_default_paths
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, Executor
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.make_mutable import make_mutable
|
||||
from ..distributed.executors import ContextVarExecutor
|
||||
from ..distributed.process_pool_executor import ProcessPoolExecutor
|
||||
from ..distributed.server_stub import ServerStub
|
||||
from ..execution_context import current_execution_context
|
||||
|
||||
_prompt_executor = contextvars.ContextVar('prompt_executor')
|
||||
_prompt_executor = threading.local()
|
||||
|
||||
|
||||
def _execute_prompt(
|
||||
@ -33,6 +37,9 @@ def _execute_prompt(
|
||||
span_context: dict,
|
||||
progress_handler: ExecutorToClientProgress | None,
|
||||
configuration: Configuration | None) -> dict:
|
||||
execution_context = current_execution_context()
|
||||
if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
|
||||
init_default_paths(execution_context.folder_names_and_paths, configuration)
|
||||
span_context: Context = propagate.extract(span_context)
|
||||
token = attach(span_context)
|
||||
try:
|
||||
@ -52,8 +59,8 @@ def __execute_prompt(
|
||||
progress_handler = progress_handler or ServerStub()
|
||||
|
||||
try:
|
||||
prompt_executor = _prompt_executor.get()
|
||||
except LookupError:
|
||||
prompt_executor: PromptExecutor = _prompt_executor.executor
|
||||
except (LookupError, AttributeError):
|
||||
if configuration is None:
|
||||
options.enable_args_parsing()
|
||||
else:
|
||||
@ -65,7 +72,7 @@ def __execute_prompt(
|
||||
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
|
||||
prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
|
||||
prompt_executor.raise_exceptions = True
|
||||
_prompt_executor.set(prompt_executor)
|
||||
_prompt_executor.executor = prompt_executor
|
||||
|
||||
with tracer.start_as_current_span("Execute Prompt", context=span_context) as span:
|
||||
try:
|
||||
@ -96,6 +103,13 @@ def __execute_prompt(
|
||||
|
||||
|
||||
def _cleanup():
|
||||
from ..cmd.execution import PromptExecutor
|
||||
try:
|
||||
prompt_executor: PromptExecutor = _prompt_executor.executor
|
||||
# this should clear all references to output tensors and make it easier to collect back the memory
|
||||
prompt_executor.reset()
|
||||
except (LookupError, AttributeError):
|
||||
pass
|
||||
from .. import model_management
|
||||
model_management.unload_all_models()
|
||||
gc.collect()
|
||||
@ -139,9 +153,9 @@ class EmbeddedComfyClient:
|
||||
In order to use this in blocking methods, learn more about asyncio online.
|
||||
"""
|
||||
|
||||
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: Executor = None):
|
||||
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor = None):
|
||||
self._progress_handler = progress_handler or ServerStub()
|
||||
self._executor = executor or ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._executor = executor or ContextVarExecutor(max_workers=max_workers)
|
||||
self._configuration = configuration
|
||||
self._is_running = False
|
||||
self._task_count_lock = RLock()
|
||||
|
||||
@ -27,7 +27,7 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
|
||||
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
|
||||
from ..component_model.files import canonicalize_path
|
||||
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
|
||||
from ..execution_context import new_execution_context, context_execute_node, ExecutionContext
|
||||
from ..execution_context import context_execute_node, context_execute_prompt
|
||||
from ..nodes.package import import_all_nodes_in_workspace
|
||||
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
|
||||
|
||||
@ -77,24 +77,19 @@ class IsChangedCache:
|
||||
class CacheSet:
|
||||
def __init__(self, lru_size=None):
|
||||
if lru_size is None or lru_size == 0:
|
||||
self.init_classic_cache()
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
else:
|
||||
self.init_lru_cache(lru_size)
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=lru_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
|
||||
# Useful for those with ample RAM/VRAM -- allows experimenting without
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
@ -308,11 +303,11 @@ def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, caches,
|
||||
:param pending_subgraph_results:
|
||||
:return:
|
||||
"""
|
||||
with context_execute_node(_node_id, prompt_id):
|
||||
with context_execute_node(_node_id):
|
||||
return _execute(server, dynprompt, caches, _node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
|
||||
|
||||
def _execute(server, dynprompt, caches, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
|
||||
def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) -> RecursiveExecutionTuple:
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@ -548,7 +543,7 @@ class PromptExecutor:
|
||||
# torchao and potentially other optimization approaches break when the models are created in inference mode
|
||||
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
|
||||
inference_mode = all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt))
|
||||
with new_execution_context(ExecutionContext(self.server, task_id=prompt_id, inference_mode=inference_mode)):
|
||||
with context_execute_prompt(self.server, prompt_id, inference_mode=inference_mode):
|
||||
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
|
||||
|
||||
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None, inference_mode: bool = True):
|
||||
|
||||
@ -4,142 +4,144 @@ import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, List, Final, Literal
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Literal
|
||||
|
||||
from .folder_paths_pre import get_base_path
|
||||
from ..cli_args_types import Configuration
|
||||
from ..component_model.deprecation import _deprecate_method
|
||||
from ..component_model.files import get_package_as_path
|
||||
from ..component_model.folder_path_types import FolderPathsTuple, FolderNames, SaveImagePathResponse
|
||||
from ..component_model.folder_path_types import extension_mimetypes_cache as _extension_mimetypes_cache
|
||||
from ..component_model.folder_path_types import supported_pt_extensions as _supported_pt_extensions
|
||||
from ..component_model.module_property import module_property
|
||||
from ..component_model.folder_path_types import FolderNames, SaveImagePathTuple, ModelPaths
|
||||
from ..component_model.folder_path_types import supported_pt_extensions, extension_mimetypes_cache
|
||||
from ..component_model.module_property import create_module_properties
|
||||
from ..component_model.platform_path import construct_path
|
||||
from ..execution_context import current_execution_context
|
||||
|
||||
supported_pt_extensions: Final[frozenset[str]] = _supported_pt_extensions
|
||||
extension_mimetypes_cache: Final[dict[str, str]] = _extension_mimetypes_cache
|
||||
_module_properties = create_module_properties()
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _supported_pt_extensions() -> frozenset[str]:
|
||||
return supported_pt_extensions
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _extension_mimetypes_cache() -> dict[str, str]:
|
||||
return extension_mimetypes_cache
|
||||
|
||||
|
||||
# todo: this needs to be wrapped in a context and configurable
|
||||
@module_property
|
||||
@_module_properties.getter
|
||||
def _base_path():
|
||||
return get_base_path()
|
||||
return _folder_names_and_paths().base_paths[0]
|
||||
|
||||
|
||||
models_dir = os.path.join(get_base_path(), "models")
|
||||
folder_names_and_paths: Final[FolderNames] = FolderNames(models_dir)
|
||||
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
|
||||
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["unet"] = folder_names_and_paths["diffusion_models"] = FolderPathsTuple("diffusion_models", [os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["diffusers"] = FolderPathsTuple("diffusers", [os.path.join(models_dir, "diffusers")], {"folder"})
|
||||
folder_names_and_paths["vae_approx"] = FolderPathsTuple("vae_approx", [os.path.join(models_dir, "vae_approx")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["controlnet"] = FolderPathsTuple("controlnet", [os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["gligen"] = FolderPathsTuple("gligen", [os.path.join(models_dir, "gligen")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["upscale_models"] = FolderPathsTuple("upscale_models", [os.path.join(models_dir, "upscale_models")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.path.join(get_base_path(), "custom_nodes")], set())
|
||||
folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
|
||||
folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""})
|
||||
folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""})
|
||||
|
||||
output_directory = os.path.join(get_base_path(), "output")
|
||||
temp_directory = os.path.join(get_base_path(), "temp")
|
||||
input_directory = os.path.join(get_base_path(), "input")
|
||||
user_directory = os.path.join(get_base_path(), "user")
|
||||
|
||||
filename_list_cache = {}
|
||||
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None):
|
||||
from ..cmd.main_pre import args
|
||||
configuration = configuration or args
|
||||
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
|
||||
base_paths = [path for path in base_paths if path is not None]
|
||||
if len(base_paths) == 0:
|
||||
base_paths = [Path(os.getcwd())]
|
||||
for base_path in base_paths:
|
||||
folder_names_and_paths.add_base_path(base_path)
|
||||
folder_names_and_paths.add(ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}))
|
||||
folder_names_and_paths.add(ModelPaths(["vae"], supported_extensions={".yaml"}))
|
||||
folder_names_and_paths.add(ModelPaths(["clip"], supported_extensions={".yaml"}))
|
||||
folder_names_and_paths.add(ModelPaths(["loras"], supported_extensions={".yaml"}))
|
||||
folder_names_and_paths.add(ModelPaths(["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["diffusers"], supported_extensions=set()))
|
||||
folder_names_and_paths.add(ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()))
|
||||
folder_names_and_paths.add(ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)))
|
||||
folder_names_and_paths.add(ModelPaths(["classifiers"], supported_extensions=set()))
|
||||
folder_names_and_paths.add(ModelPaths(["huggingface"], supported_extensions=set()))
|
||||
hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set())
|
||||
# TODO: explore if there is a better way to do this
|
||||
if "HF_HUB_CACHE" in os.environ:
|
||||
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
|
||||
folder_names_and_paths.add(hf_cache_paths)
|
||||
create_directories(folder_names_and_paths)
|
||||
|
||||
|
||||
class CacheHelper:
|
||||
"""
|
||||
Helper class for managing file list cache data.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
|
||||
self.active = False
|
||||
|
||||
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
|
||||
if not self.active:
|
||||
return default
|
||||
return self.cache.get(key, default)
|
||||
|
||||
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
|
||||
if self.active:
|
||||
self.cache[key] = value
|
||||
|
||||
def clear(self):
|
||||
self.cache.clear()
|
||||
|
||||
def __enter__(self):
|
||||
self.active = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.active = False
|
||||
self.clear()
|
||||
@_module_properties.getter
|
||||
def _folder_names_and_paths():
|
||||
return current_execution_context().folder_names_and_paths
|
||||
|
||||
|
||||
cache_helper = CacheHelper()
|
||||
@_module_properties.getter
|
||||
def _models_dir():
|
||||
return str(Path(current_execution_context().folder_names_and_paths.base_paths[0]) / construct_path("models"))
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _user_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve())
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _temp_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve())
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _input_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve())
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _output_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve())
|
||||
|
||||
|
||||
@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.")
|
||||
def map_legacy(folder_name: str) -> str:
|
||||
legacy = {"unet": "diffusion_models"}
|
||||
return legacy.get(folder_name, folder_name)
|
||||
|
||||
|
||||
if not os.path.exists(input_directory):
|
||||
try:
|
||||
os.makedirs(input_directory)
|
||||
except:
|
||||
logging.error("Failed to create input directory")
|
||||
def set_output_directory(output_dir: str | Path):
|
||||
_folder_names_and_paths().application_paths.output_directory = construct_path(output_dir)
|
||||
|
||||
|
||||
def set_output_directory(output_dir):
|
||||
global output_directory
|
||||
output_directory = output_dir
|
||||
def set_temp_directory(temp_dir: str | Path):
|
||||
_folder_names_and_paths().application_paths.temp_directory = construct_path(temp_dir)
|
||||
|
||||
|
||||
def set_temp_directory(temp_dir):
|
||||
global temp_directory
|
||||
temp_directory = temp_dir
|
||||
def set_input_directory(input_dir: str | Path):
|
||||
_folder_names_and_paths().application_paths.input_directory = construct_path(input_dir)
|
||||
|
||||
|
||||
def set_input_directory(input_dir):
|
||||
global input_directory
|
||||
input_directory = input_dir
|
||||
def get_output_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve())
|
||||
|
||||
|
||||
def get_output_directory():
|
||||
global output_directory
|
||||
return output_directory
|
||||
def get_temp_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve())
|
||||
|
||||
|
||||
def get_temp_directory():
|
||||
global temp_directory
|
||||
return temp_directory
|
||||
|
||||
|
||||
def get_input_directory():
|
||||
global input_directory
|
||||
return input_directory
|
||||
def get_input_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve())
|
||||
|
||||
|
||||
def get_user_directory() -> str:
|
||||
return user_directory
|
||||
return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve())
|
||||
|
||||
|
||||
def set_user_directory(user_dir: str) -> None:
|
||||
global user_directory
|
||||
user_directory = user_dir
|
||||
def set_user_directory(user_dir: str | Path) -> None:
|
||||
_folder_names_and_paths().application_paths.user_directory = construct_path(user_dir)
|
||||
|
||||
|
||||
# NOTE: used in http server so don't put folders that should not be accessed remotely
|
||||
def get_directory_by_type(type_name):
|
||||
def get_directory_by_type(type_name) -> str | None:
|
||||
if type_name == "output":
|
||||
return get_output_directory()
|
||||
if type_name == "temp":
|
||||
@ -151,7 +153,7 @@ def get_directory_by_type(type_name):
|
||||
|
||||
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
|
||||
# otherwise use default_path as base_dir
|
||||
def annotated_filepath(name):
|
||||
def annotated_filepath(name: str) -> tuple[str, str | None]:
|
||||
if name.endswith("[output]"):
|
||||
base_dir = get_output_directory()
|
||||
name = name[:-9]
|
||||
@ -167,7 +169,7 @@ def annotated_filepath(name):
|
||||
return name, base_dir
|
||||
|
||||
|
||||
def get_annotated_filepath(name, default_dir=None):
|
||||
def get_annotated_filepath(name, default_dir=None) -> str:
|
||||
name, base_dir = annotated_filepath(name)
|
||||
|
||||
if base_dir is None:
|
||||
@ -198,9 +200,11 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
|
||||
:param extensions: supported file extensions
|
||||
:return: the folder path
|
||||
"""
|
||||
global folder_names_and_paths
|
||||
folder_names_and_paths = _folder_names_and_paths()
|
||||
if full_folder_path is None:
|
||||
full_folder_path = os.path.join(models_dir, folder_name)
|
||||
# todo: this should use the subdir patter
|
||||
|
||||
full_folder_path = os.path.join(_models_dir(), folder_name)
|
||||
|
||||
folder_path = folder_names_and_paths[folder_name]
|
||||
if full_folder_path not in folder_path.paths:
|
||||
@ -212,48 +216,16 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
|
||||
if extensions is not None:
|
||||
folder_path.supported_extensions |= extensions
|
||||
|
||||
invalidate_cache(folder_name)
|
||||
return full_folder_path
|
||||
|
||||
|
||||
def get_folder_paths(folder_name) -> List[str]:
|
||||
return folder_names_and_paths[folder_name].paths[:]
|
||||
return [path for path in _folder_names_and_paths()[folder_name].paths]
|
||||
|
||||
|
||||
def recursive_search(directory, excluded_dir_names=None):
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}
|
||||
|
||||
if excluded_dir_names is None:
|
||||
excluded_dir_names = []
|
||||
|
||||
result = []
|
||||
dirs = {}
|
||||
|
||||
# Attempt to add the initial directory to dirs with error handling
|
||||
try:
|
||||
dirs[directory] = os.path.getmtime(directory)
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
path = os.path.join(dirpath, d)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
return result, dirs
|
||||
@_deprecate_method(version="0.2.3", message="Not supported")
|
||||
def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]:
|
||||
raise NotImplemented("Unsupported method")
|
||||
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
@ -264,92 +236,27 @@ def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
|
||||
"""
|
||||
Gets the path to a filename inside a folder.
|
||||
|
||||
Works with untrusted filenames.
|
||||
:param folder_name:
|
||||
:param filename:
|
||||
:return:
|
||||
"""
|
||||
global folder_names_and_paths
|
||||
folders = folder_names_and_paths[folder_name].paths
|
||||
filename_split = os.path.split(filename)
|
||||
|
||||
trusted_paths = []
|
||||
for folder in folders:
|
||||
folder_split = os.path.split(folder)
|
||||
abs_file_path = os.path.abspath(os.path.join(*folder_split, *filename_split))
|
||||
abs_folder_path = os.path.abspath(folder)
|
||||
if os.path.commonpath([abs_file_path, abs_folder_path]) == abs_folder_path:
|
||||
trusted_paths.append(abs_file_path)
|
||||
else:
|
||||
logging.error(f"attempted to access untrusted path {abs_file_path} in {folder_name} for filename {filename}")
|
||||
|
||||
for trusted_path in trusted_paths:
|
||||
if os.path.isfile(trusted_path):
|
||||
return trusted_path
|
||||
|
||||
return None
|
||||
path = _folder_names_and_paths().first_existing_or_none(folder_name, construct_path(filename))
|
||||
return str(path) if path is not None else None
|
||||
|
||||
|
||||
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
|
||||
full_path = get_full_path(folder_name, filename)
|
||||
if full_path is None:
|
||||
# todo: probably shouldn't say model
|
||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
|
||||
return full_path
|
||||
|
||||
|
||||
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
global folder_names_and_paths
|
||||
output_list = set()
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
output_folders = {}
|
||||
for x in folders[0]:
|
||||
files, folders_all = recursive_search(x, excluded_dir_names=[".git"])
|
||||
output_list.update(filter_files_extensions(files, folders[1]))
|
||||
output_folders = {**output_folders, **folders_all}
|
||||
|
||||
return sorted(list(output_list)), output_folders, time.perf_counter()
|
||||
|
||||
|
||||
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
|
||||
strong_cache = cache_helper.get(folder_name)
|
||||
if strong_cache is not None:
|
||||
return strong_cache
|
||||
|
||||
global filename_list_cache
|
||||
global folder_names_and_paths
|
||||
folder_name = map_legacy(folder_name)
|
||||
if folder_name not in filename_list_cache:
|
||||
return None
|
||||
out = filename_list_cache[folder_name]
|
||||
|
||||
for x in out[1]:
|
||||
time_modified = out[1][x]
|
||||
folder = x
|
||||
if os.path.getmtime(folder) != time_modified:
|
||||
return None
|
||||
|
||||
folders = folder_names_and_paths[folder_name]
|
||||
for x in folders[0]:
|
||||
if os.path.isdir(x):
|
||||
if x not in out[1]:
|
||||
return None
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_filename_list(folder_name: str) -> list[str]:
|
||||
folder_name = map_legacy(folder_name)
|
||||
out = cached_filename_list_(folder_name)
|
||||
if out is None:
|
||||
out = get_filename_list_(folder_name)
|
||||
global filename_list_cache
|
||||
filename_list_cache[folder_name] = out
|
||||
cache_helper.set(folder_name, out)
|
||||
return list(out[0])
|
||||
return [str(path) for path in _folder_names_and_paths().file_paths(folder_name=folder_name, relative=True)]
|
||||
|
||||
|
||||
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
|
||||
def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0) -> SaveImagePathTuple:
|
||||
def map_filename(filename: str) -> tuple[int, str]:
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
@ -378,14 +285,6 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
|
||||
|
||||
full_output_folder = str(os.path.join(output_dir, subfolder))
|
||||
|
||||
if str(os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))) != str(output_dir):
|
||||
err = f"""**** ERROR: Saving image outside the output folder is not allowed.
|
||||
full_output_folder: {os.path.abspath(full_output_folder)}
|
||||
output_dir: {output_dir}
|
||||
commonpath: {os.path.commonpath((output_dir, os.path.abspath(full_output_folder)))}"""
|
||||
logging.error(err)
|
||||
raise Exception(err)
|
||||
|
||||
try:
|
||||
counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1
|
||||
except ValueError:
|
||||
@ -393,21 +292,22 @@ def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height
|
||||
except FileNotFoundError:
|
||||
os.makedirs(full_output_folder, exist_ok=True)
|
||||
counter = 1
|
||||
return SaveImagePathResponse(full_output_folder, filename, counter, subfolder, filename_prefix)
|
||||
return SaveImagePathTuple(full_output_folder, filename, counter, subfolder, filename_prefix)
|
||||
|
||||
|
||||
def create_directories():
|
||||
def create_directories(paths: FolderNames | None):
|
||||
# all configured paths should be created
|
||||
for folder_path_spec in folder_names_and_paths.values():
|
||||
paths = paths or _folder_names_and_paths()
|
||||
for folder_path_spec in paths.values():
|
||||
for path in folder_path_spec.paths:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
for path in (temp_directory, input_directory, output_directory, user_directory):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
for path in paths.application_paths:
|
||||
path.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
@_deprecate_method(version="0.2.3", message="Caching has been removed.")
|
||||
def invalidate_cache(folder_name):
|
||||
global filename_list_cache
|
||||
filename_list_cache.pop(folder_name, None)
|
||||
pass
|
||||
|
||||
|
||||
def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
|
||||
@ -416,7 +316,7 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
|
||||
files = os.listdir(folder_paths.get_input_directory())
|
||||
filter_files_content_types(files, ["image", "audio", "video"])
|
||||
"""
|
||||
global extension_mimetypes_cache
|
||||
extension_mimetypes_cache = _extension_mimetypes_cache()
|
||||
result = []
|
||||
for file in files:
|
||||
extension = file.split('.')[-1]
|
||||
@ -432,3 +332,52 @@ def filter_files_content_types(files: list[str], content_types: list[Literal["im
|
||||
if content_type in content_types:
|
||||
result.append(file)
|
||||
return result
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _cache_helper():
|
||||
return nullcontext()
|
||||
|
||||
|
||||
# todo: can this be done side effect free?
|
||||
init_default_paths(_folder_names_and_paths())
|
||||
|
||||
__all__ = [
|
||||
# Properties (stripped leading underscore)
|
||||
"supported_pt_extensions", # from _supported_pt_extensions
|
||||
"extension_mimetypes_cache", # from _extension_mimetypes_cache
|
||||
"base_path", # from _base_path
|
||||
"folder_names_and_paths", # from _folder_names_and_paths
|
||||
"models_dir", # from _models_dir
|
||||
"user_directory",
|
||||
"output_directory",
|
||||
"temp_directory",
|
||||
"input_directory",
|
||||
|
||||
# Public functions
|
||||
"init_default_paths",
|
||||
"map_legacy",
|
||||
"set_output_directory",
|
||||
"set_temp_directory",
|
||||
"set_input_directory",
|
||||
"get_output_directory",
|
||||
"get_temp_directory",
|
||||
"get_input_directory",
|
||||
"get_user_directory",
|
||||
"set_user_directory",
|
||||
"get_directory_by_type",
|
||||
"annotated_filepath",
|
||||
"get_annotated_filepath",
|
||||
"exists_annotated_filepath",
|
||||
"add_model_folder_path",
|
||||
"get_folder_paths",
|
||||
"recursive_search",
|
||||
"filter_files_extensions",
|
||||
"get_full_path",
|
||||
"get_full_path_or_raise",
|
||||
"get_filename_list",
|
||||
"get_save_image_path",
|
||||
"create_directories",
|
||||
"invalidate_cache",
|
||||
"filter_files_content_types"
|
||||
]
|
||||
|
||||
@ -1,31 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..cli_args import args
|
||||
|
||||
_base_path = None
|
||||
|
||||
|
||||
# todo: this should be initialized elsewhere in a context
|
||||
def get_base_path() -> str:
|
||||
global _base_path
|
||||
if _base_path is None:
|
||||
if args.cwd is not None:
|
||||
if not os.path.exists(args.cwd):
|
||||
try:
|
||||
os.makedirs(args.cwd, exist_ok=True)
|
||||
except:
|
||||
logging.error("Failed to create custom working directory")
|
||||
# wrap the path to prevent slashedness from glitching out common path checks
|
||||
_base_path = os.path.realpath(args.cwd)
|
||||
else:
|
||||
_base_path = os.getcwd()
|
||||
return _base_path
|
||||
|
||||
|
||||
def set_base_path(value: str):
|
||||
global _base_path
|
||||
_base_path = value
|
||||
|
||||
|
||||
__all__ = ["get_base_path", "set_base_path"]
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import contextvars
|
||||
import gc
|
||||
import itertools
|
||||
import logging
|
||||
@ -193,7 +194,9 @@ async def main(from_script_dir: Optional[Path] = None):
|
||||
if not distributed or args.distributed_queue_worker:
|
||||
if distributed:
|
||||
logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, worker_thread_server,)).start()
|
||||
# todo: this should really be using an executor instead of doing things this jankilicious way
|
||||
ctx = contextvars.copy_context()
|
||||
threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start()
|
||||
|
||||
# server has been imported and things should be looking good
|
||||
initialize_event_tracking(loop)
|
||||
|
||||
@ -111,13 +111,7 @@ def _create_tracer():
|
||||
|
||||
|
||||
def _configure_logging():
|
||||
if isinstance(args.verbose, str):
|
||||
logging_level = args.verbose
|
||||
elif args.verbose == True:
|
||||
logging_level = logging.DEBUG
|
||||
else:
|
||||
logging_level = logging.ERROR
|
||||
|
||||
logging_level = args.logging_level
|
||||
logging.basicConfig(format="%(message)s", level=logging_level)
|
||||
|
||||
|
||||
|
||||
@ -28,8 +28,8 @@ from aiohttp import web
|
||||
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
||||
from typing_extensions import NamedTuple
|
||||
|
||||
from .. import __version__
|
||||
from .latent_preview_image_encoding import encode_preview_image
|
||||
from .. import __version__
|
||||
from .. import interruption, model_management
|
||||
from .. import node_helpers
|
||||
from .. import utils
|
||||
@ -241,7 +241,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
return response
|
||||
|
||||
@routes.get("/embeddings")
|
||||
def get_embeddings(self):
|
||||
def get_embeddings(request):
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||
|
||||
@ -568,15 +568,14 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in self.nodes.NODE_CLASS_MAPPINGS:
|
||||
try:
|
||||
out[x] = node_info(x)
|
||||
except Exception as e:
|
||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||
logging.error(traceback.format_exc())
|
||||
return web.json_response(out)
|
||||
out = {}
|
||||
for x in self.nodes.NODE_CLASS_MAPPINGS:
|
||||
try:
|
||||
out[x] = node_info(x)
|
||||
except Exception as e:
|
||||
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
|
||||
logging.error(traceback.format_exc())
|
||||
return web.json_response(out)
|
||||
|
||||
@routes.get("/object_info/{node_class}")
|
||||
async def get_object_info_node(request):
|
||||
@ -721,21 +720,21 @@ class PromptServer(ExecutorToClientProgress):
|
||||
async def get_api_v1_prompts_prompt_id(request: web.Request) -> web.Response | web.FileResponse:
|
||||
prompt_id: str = request.match_info.get("prompt_id", "")
|
||||
if prompt_id == "":
|
||||
return web.Response(status=404)
|
||||
return web.json_response(status=404)
|
||||
|
||||
history_items = self.prompt_queue.get_history(prompt_id)
|
||||
if len(history_items) == 0 or prompt_id not in history_items:
|
||||
# todo: this should really be moved to a stateful queue abstraction
|
||||
if prompt_id in self.background_tasks:
|
||||
return web.Response(status=204)
|
||||
return web.json_response(status=204)
|
||||
else:
|
||||
# todo: this should check a stateful queue abstraction
|
||||
return web.Response(status=404)
|
||||
return web.json_response(status=404)
|
||||
elif prompt_id in history_items:
|
||||
history_entry = history_items[prompt_id]
|
||||
return web.json_response(history_entry["outputs"])
|
||||
else:
|
||||
return web.Response(status=500)
|
||||
return web.json_response(status=500)
|
||||
|
||||
@routes.post("/api/v1/prompts")
|
||||
async def post_api_prompt(request: web.Request) -> web.Response | web.FileResponse:
|
||||
@ -751,8 +750,8 @@ class PromptServer(ExecutorToClientProgress):
|
||||
queue_size = self.prompt_queue.size()
|
||||
queue_too_busy_size = PromptServer.get_too_busy_queue_size()
|
||||
if queue_size > queue_too_busy_size:
|
||||
return web.Response(status=429,
|
||||
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
|
||||
return web.json_response(status=429,
|
||||
reason=f"the queue has {queue_size} elements and {queue_too_busy_size} is the limit for this worker")
|
||||
# read the request
|
||||
prompt_dict: dict = {}
|
||||
if content_type == 'application/json':
|
||||
|
||||
@ -2,10 +2,11 @@ import asyncio
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
from .main_pre import args
|
||||
from .extra_model_paths import load_extra_path_config
|
||||
from ..distributed.executors import ContextVarExecutor
|
||||
|
||||
|
||||
async def main():
|
||||
@ -43,7 +44,7 @@ async def main():
|
||||
from ..distributed.distributed_prompt_worker import DistributedPromptWorker
|
||||
async with DistributedPromptWorker(connection_uri=args.distributed_queue_connection_uri,
|
||||
queue_name=args.distributed_queue_name,
|
||||
executor=ThreadPoolExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
|
||||
executor=ContextVarExecutor(max_workers=1) if args.executor_factory == "ThreadPoolExecutor" else ProcessPoolExecutor(max_workers=1)):
|
||||
stop = asyncio.Event()
|
||||
try:
|
||||
await stop.wait()
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Optional, Literal, Protocol, Union, NamedTuple, List
|
||||
import PIL.Image
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .outputs_types import OutputsDict
|
||||
from .queue_types import BinaryEventTypes
|
||||
from ..cli_args_types import Configuration
|
||||
from ..nodes.package_typing import InputTypeSpec
|
||||
@ -205,8 +206,8 @@ class DuplicateNodeError(Exception):
|
||||
|
||||
|
||||
class HistoryResultDict(TypedDict, total=True):
|
||||
outputs: dict
|
||||
meta: dict
|
||||
outputs: OutputsDict
|
||||
meta: OutputsDict
|
||||
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
|
||||
@ -1,15 +1,9 @@
|
||||
import os
|
||||
from typing import Literal, Optional
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
from ..cmd import folder_paths
|
||||
|
||||
|
||||
def _is_strictly_below_root(path: Path) -> bool:
|
||||
resolved_path = path.resolve()
|
||||
return ".." not in resolved_path.parts and resolved_path.is_absolute()
|
||||
|
||||
|
||||
def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "output",
|
||||
subfolder: Optional[str] = None) -> str:
|
||||
"""
|
||||
@ -22,22 +16,15 @@ def file_output_path(filename: str, type: Literal["input", "output", "temp"] = "
|
||||
:return:
|
||||
"""
|
||||
filename, output_dir = folder_paths.annotated_filepath(str(filename))
|
||||
if not _is_strictly_below_root(Path(filename)):
|
||||
raise PermissionError("insecure")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
if output_dir is None:
|
||||
raise ValueError(f"no such output directory because invalid type specified (type={type})")
|
||||
if subfolder is not None and subfolder != "":
|
||||
full_output_dir = str(os.path.join(output_dir, subfolder))
|
||||
if str(os.path.commonpath([os.path.abspath(full_output_dir), output_dir])) != str(output_dir):
|
||||
raise PermissionError("insecure")
|
||||
output_dir = full_output_dir
|
||||
filename = os.path.basename(filename)
|
||||
else:
|
||||
if str(os.path.commonpath([os.path.abspath(output_dir), os.path.join(output_dir, filename)])) != str(output_dir):
|
||||
raise PermissionError("insecure")
|
||||
|
||||
file = os.path.join(output_dir, filename)
|
||||
return file
|
||||
output_dir = Path(output_dir)
|
||||
subfolder = Path(subfolder or "")
|
||||
try:
|
||||
relative_to = (output_dir / subfolder / filename).relative_to(output_dir)
|
||||
except ValueError:
|
||||
raise PermissionError(f"{output_dir / subfolder / filename} is not a subpath of {output_dir}")
|
||||
return str((output_dir / relative_to).resolve(strict=True))
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import os
|
||||
import typing
|
||||
from typing import List, Set, Any, Iterator, Sequence, Dict, NamedTuple
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, NamedTuple, Optional, Iterable
|
||||
|
||||
from pathlib import Path
|
||||
from .platform_path import construct_path
|
||||
|
||||
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"])
|
||||
extension_mimetypes_cache = {
|
||||
@ -11,35 +17,267 @@ extension_mimetypes_cache = {
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
def do_add(collection: list | set, index: int | None, item: Any):
|
||||
if isinstance(collection, list) and index == 0:
|
||||
collection.insert(0, item)
|
||||
elif isinstance(collection, set):
|
||||
collection.add(item)
|
||||
else:
|
||||
assert isinstance(collection, list)
|
||||
collection.append(item)
|
||||
|
||||
|
||||
class FolderPathsTuple:
|
||||
folder_name: str
|
||||
paths: List[str] = dataclasses.field(default_factory=list)
|
||||
supported_extensions: Set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
|
||||
def __init__(self, folder_name: str = None, paths: list[str] = None, supported_extensions: set[str] = None, parent: Optional[weakref.ReferenceType[FolderNames]] = None):
|
||||
paths = paths or []
|
||||
supported_extensions = supported_extensions or set(supported_pt_extensions)
|
||||
|
||||
self.folder_name = folder_name
|
||||
self.parent = parent
|
||||
self._paths = paths
|
||||
self._supported_extensions = supported_extensions
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> set[str] | SupportedExtensions:
|
||||
if self.parent is not None and self.folder_name is not None:
|
||||
return SupportedExtensions(self.folder_name, self.parent)
|
||||
else:
|
||||
return self._supported_extensions
|
||||
|
||||
@supported_extensions.setter
|
||||
def supported_extensions(self, value: set[str] | SupportedExtensions):
|
||||
self.supported_extensions.clear()
|
||||
self.supported_extensions.update(value)
|
||||
|
||||
@property
|
||||
def paths(self) -> list[str] | PathsList:
|
||||
if self.parent is not None and self.folder_name is not None:
|
||||
return PathsList(self.folder_name, self.parent)
|
||||
else:
|
||||
return self._paths
|
||||
|
||||
def __iter__(self) -> typing.Generator[typing.Iterable[str]]:
|
||||
"""
|
||||
allows this proxy to behave like a tuple everywhere it is used by the custom nodes
|
||||
:return:
|
||||
"""
|
||||
yield self.paths
|
||||
yield self.supported_extensions
|
||||
|
||||
def __getitem__(self, item: Any):
|
||||
if item == 0:
|
||||
return self.paths
|
||||
if item == 1:
|
||||
return self.supported_extensions
|
||||
else:
|
||||
raise RuntimeError("unsupported tuple index")
|
||||
|
||||
def __add__(self, other: "FolderPathsTuple"):
|
||||
assert self.folder_name == other.folder_name
|
||||
# todo: make sure the paths are actually unique, as this method intends
|
||||
new_paths = list(frozenset(self.paths + other.paths))
|
||||
new_supported_extensions = self.supported_extensions | other.supported_extensions
|
||||
return FolderPathsTuple(self.folder_name, new_paths, new_supported_extensions)
|
||||
raise RuntimeError("unsupported tuple index")
|
||||
|
||||
def __iter__(self) -> Iterator[Sequence[str]]:
|
||||
yield self.paths
|
||||
yield self.supported_extensions
|
||||
def __iadd__(self, other: FolderPathsTuple):
|
||||
for path in other.paths:
|
||||
self.paths.append(path)
|
||||
|
||||
for ext in other.supported_extensions:
|
||||
self.supported_extensions.add(ext)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PathsList:
|
||||
folder_name: str
|
||||
parent: weakref.ReferenceType[FolderNames]
|
||||
|
||||
def __iter__(self) -> typing.Generator[str]:
|
||||
p: FolderNames = self.parent()
|
||||
for path in p.directory_paths(self.folder_name):
|
||||
try:
|
||||
yield str(path.resolve())
|
||||
except (OSError, AttributeError):
|
||||
yield str(path)
|
||||
|
||||
def __getitem__(self, item: int):
|
||||
paths = [x for x in self]
|
||||
return paths[item]
|
||||
|
||||
def append(self, path_str: str):
|
||||
p: FolderNames = self.parent()
|
||||
p.add_paths(self.folder_name, [path_str])
|
||||
|
||||
def insert(self, path_str: str, index: int):
|
||||
p: FolderNames = self.parent()
|
||||
p.add_paths(self.folder_name, [path_str], index=index)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SupportedExtensions:
|
||||
folder_name: str
|
||||
parent: weakref.ReferenceType[FolderNames]
|
||||
|
||||
def _append_any(self, other):
|
||||
if other is None:
|
||||
return
|
||||
|
||||
p: FolderNames = self.parent()
|
||||
if isinstance(other, str):
|
||||
other = {other}
|
||||
p.add_supported_extension(self.folder_name, *other)
|
||||
|
||||
def __iter__(self) -> typing.Generator[str]:
|
||||
p: FolderNames = self.parent()
|
||||
for ext in p.supported_extensions(self.folder_name):
|
||||
yield ext
|
||||
|
||||
def clear(self):
|
||||
p: FolderNames = self.parent()
|
||||
p.remove_all_supported_extensions(self.folder_name)
|
||||
|
||||
__ior__ = _append_any
|
||||
add = _append_any
|
||||
update = _append_any
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ApplicationPaths:
|
||||
output_directory: Path = dataclasses.field(default_factory=lambda: construct_path("output"))
|
||||
temp_directory: Path = dataclasses.field(default_factory=lambda: construct_path("temp"))
|
||||
input_directory: Path = dataclasses.field(default_factory=lambda: construct_path("input"))
|
||||
user_directory: Path = dataclasses.field(default_factory=lambda: construct_path("user"))
|
||||
|
||||
def __iter__(self) -> typing.Generator[Path]:
|
||||
yield self.output_directory
|
||||
yield self.temp_directory
|
||||
yield self.input_directory
|
||||
yield self.user_directory
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AbstractPaths(ABC):
|
||||
folder_names: list[str]
|
||||
supported_extensions: set[str] = dataclasses.field(default_factory=lambda: set(supported_pt_extensions))
|
||||
|
||||
@abstractmethod
|
||||
def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]:
|
||||
"""Generate directory paths based on the given base paths."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]:
|
||||
"""Generate file paths based on the given base paths."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def has_folder_name(self, folder_name: str) -> bool:
|
||||
"""Check if the given folder name is in folder_names."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModelPaths(AbstractPaths):
|
||||
folder_name_base_path_subdir: Path = dataclasses.field(default_factory=lambda: construct_path("models"))
|
||||
additional_relative_directory_paths: set[Path] = dataclasses.field(default_factory=set)
|
||||
additional_absolute_directory_paths: set[str | Path] = dataclasses.field(default_factory=set)
|
||||
folder_names_are_relative_directory_paths_too: bool = dataclasses.field(default_factory=lambda: True)
|
||||
|
||||
def directory_paths(self, base_paths: Iterable[Path]) -> typing.Generator[Path]:
|
||||
yielded_so_far: set[Path] = set()
|
||||
for base_path in base_paths:
|
||||
if self.folder_names_are_relative_directory_paths_too:
|
||||
for folder_name in self.folder_names:
|
||||
resolved_default_folder_name_path = base_path / self.folder_name_base_path_subdir / folder_name
|
||||
if not resolved_default_folder_name_path in yielded_so_far:
|
||||
yield resolved_default_folder_name_path
|
||||
yielded_so_far.add(resolved_default_folder_name_path)
|
||||
|
||||
for additional_relative_directory_path in self.additional_relative_directory_paths:
|
||||
resolved_additional_relative_path = base_path / additional_relative_directory_path
|
||||
if not resolved_additional_relative_path in yielded_so_far:
|
||||
yield resolved_additional_relative_path
|
||||
yielded_so_far.add(resolved_additional_relative_path)
|
||||
|
||||
# resolve all paths
|
||||
yielded_so_far = {path.resolve() for path in yielded_so_far}
|
||||
for additional_absolute_path in self.additional_absolute_directory_paths:
|
||||
try:
|
||||
resolved_absolute_path = additional_absolute_path.resolve()
|
||||
except (OSError, AttributeError):
|
||||
resolved_absolute_path = additional_absolute_path
|
||||
if not resolved_absolute_path in yielded_so_far:
|
||||
yield resolved_absolute_path
|
||||
yielded_so_far.add(resolved_absolute_path)
|
||||
|
||||
def file_paths(self, base_paths: Iterable[Path], relative=False) -> typing.Generator[Path]:
|
||||
for path in self.directory_paths(base_paths):
|
||||
for dirpath, dirnames, filenames in os.walk(path, followlinks=True):
|
||||
if '.git' in dirnames:
|
||||
dirnames.remove('.git')
|
||||
|
||||
for filename in filenames:
|
||||
if any(filename.endswith(ext) for ext in self.supported_extensions):
|
||||
result_path = construct_path(dirpath) / filename
|
||||
if relative:
|
||||
yield result_path.relative_to(path)
|
||||
else:
|
||||
yield result_path
|
||||
|
||||
def has_folder_name(self, folder_name: str) -> bool:
|
||||
return folder_name in self.folder_names
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FolderNames:
|
||||
application_paths: typing.Optional[ApplicationPaths] = dataclasses.field(default_factory=ApplicationPaths)
|
||||
contents: list[AbstractPaths] = dataclasses.field(default_factory=list)
|
||||
base_paths: list[Path] = dataclasses.field(default_factory=list)
|
||||
|
||||
def supported_extensions(self, folder_name: str) -> typing.Generator[str]:
|
||||
for candidate in self.contents:
|
||||
if candidate.has_folder_name(folder_name):
|
||||
for supported_extensions in candidate.supported_extensions:
|
||||
for supported_extension in supported_extensions:
|
||||
yield supported_extension
|
||||
|
||||
def directory_paths(self, folder_name: str) -> typing.Generator[Path]:
|
||||
for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths)
|
||||
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
|
||||
yield directory_path
|
||||
|
||||
def file_paths(self, folder_name: str, relative=False) -> typing.Generator[Path]:
|
||||
for file_path in itertools.chain.from_iterable([candidate.file_paths(self.base_paths, relative=relative)
|
||||
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
|
||||
yield file_path
|
||||
|
||||
def first_existing_or_none(self, folder_name: str, relative_file_path: Path) -> Optional[Path]:
|
||||
for directory_path in itertools.chain.from_iterable([candidate.directory_paths(self.base_paths)
|
||||
for candidate in self.contents if candidate.has_folder_name(folder_name)]):
|
||||
candidate_file_path: Path = construct_path(directory_path / relative_file_path)
|
||||
try:
|
||||
# todo: this should follow the symlink
|
||||
if Path.exists(candidate_file_path):
|
||||
return candidate_file_path
|
||||
except OSError:
|
||||
continue
|
||||
return None
|
||||
|
||||
def add_supported_extension(self, folder_name: str, *supported_extensions: str | None):
|
||||
if supported_extensions is None:
|
||||
return
|
||||
|
||||
for candidate in self.contents:
|
||||
if candidate.has_folder_name(folder_name):
|
||||
candidate.supported_extensions.update(supported_extensions)
|
||||
|
||||
def remove_all_supported_extensions(self, folder_name: str):
|
||||
for candidate in self.contents:
|
||||
if candidate.has_folder_name(folder_name):
|
||||
candidate.supported_extensions.clear()
|
||||
|
||||
def add(self, model_paths: AbstractPaths):
|
||||
self.contents.append(model_paths)
|
||||
|
||||
def add_base_path(self, base_path: Path):
|
||||
if base_path not in self.base_paths:
|
||||
self.base_paths.append(base_path)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(folder_paths_dict: dict[str, tuple[typing.Sequence[str], Sequence[str]]] = None) -> FolderNames:
|
||||
def from_dict(folder_paths_dict: dict[str, tuple[typing.Iterable[str], Iterable[str]]] = None) -> FolderNames:
|
||||
"""
|
||||
Turns a dictionary of
|
||||
{
|
||||
@ -50,65 +288,179 @@ class FolderNames:
|
||||
:param folder_paths_dict: A dictionary
|
||||
:return: A FolderNames object
|
||||
"""
|
||||
if folder_paths_dict is None:
|
||||
return FolderNames(os.getcwd())
|
||||
raise NotImplementedError()
|
||||
|
||||
fn = FolderNames(os.getcwd())
|
||||
for folder_name, (paths, extensions) in folder_paths_dict.items():
|
||||
paths_tuple = FolderPathsTuple(folder_name=folder_name, paths=list(paths), supported_extensions=set(extensions))
|
||||
|
||||
if folder_name in fn:
|
||||
fn[folder_name] += paths_tuple
|
||||
else:
|
||||
fn[folder_name] = paths_tuple
|
||||
return fn
|
||||
|
||||
def __init__(self, default_new_folder_path: str):
|
||||
self.contents: Dict[str, FolderPathsTuple] = dict()
|
||||
self.default_new_folder_path = default_new_folder_path
|
||||
|
||||
def __getitem__(self, item) -> FolderPathsTuple:
|
||||
if not isinstance(item, str):
|
||||
def __getitem__(self, folder_name) -> FolderPathsTuple:
|
||||
if not isinstance(folder_name, str) or folder_name is None:
|
||||
raise RuntimeError("expected folder path")
|
||||
if item not in self.contents:
|
||||
default_path = os.path.join(self.default_new_folder_path, item)
|
||||
os.makedirs(default_path, exist_ok=True)
|
||||
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
|
||||
return self.contents[item]
|
||||
# todo: it is probably never the intention to do this
|
||||
try:
|
||||
path = Path(folder_name)
|
||||
if path.is_absolute():
|
||||
folder_name = path.stem
|
||||
except Exception:
|
||||
pass
|
||||
if not any(candidate.has_folder_name(folder_name) for candidate in self.contents):
|
||||
self.add(ModelPaths(folder_names=[folder_name], folder_name_base_path_subdir=construct_path(), supported_extensions=set(), folder_names_are_relative_directory_paths_too=False))
|
||||
return FolderPathsTuple(folder_name, parent=weakref.ref(self))
|
||||
|
||||
def __setitem__(self, key: str, value: FolderPathsTuple):
|
||||
assert isinstance(key, str)
|
||||
if isinstance(value, tuple):
|
||||
def add_paths(self, folder_name: str, paths: list[Path | str], index: Optional[int] = None):
|
||||
"""
|
||||
Adds, but does not create, new model paths
|
||||
:param folder_name:
|
||||
:param paths:
|
||||
:param index:
|
||||
:return:
|
||||
"""
|
||||
for candidate in self.contents:
|
||||
if candidate.has_folder_name(folder_name):
|
||||
self._modify_model_paths(folder_name, paths, set(), candidate, index=index)
|
||||
|
||||
def _modify_model_paths(self, key: str, paths: Iterable[Path | str], supported_extensions: set[str], model_paths: AbstractPaths = None, index: Optional[int] = None) -> AbstractPaths:
|
||||
model_paths = model_paths or ModelPaths([key],
|
||||
supported_extensions=set(supported_extensions),
|
||||
folder_names_are_relative_directory_paths_too=False)
|
||||
if index is not None and index != 0:
|
||||
raise ValueError(f"index was {index} but only 0 or None is supported")
|
||||
|
||||
for path in paths:
|
||||
if isinstance(path, str):
|
||||
path = construct_path(path)
|
||||
# when given absolute paths, try to formulate them as relative paths anyway
|
||||
if path.is_absolute():
|
||||
for base_path in self.base_paths:
|
||||
try:
|
||||
relative_to_basepath = path.relative_to(base_path)
|
||||
potential_folder_name = relative_to_basepath.stem
|
||||
potential_subdir = relative_to_basepath.parent
|
||||
|
||||
# does the folder_name so far match the key?
|
||||
# or have we not seen this folder before?
|
||||
folder_name_not_already_in_contents = all(not candidate.has_folder_name(potential_folder_name) for candidate in self.contents)
|
||||
if potential_folder_name == key or folder_name_not_already_in_contents:
|
||||
# fix the subdir
|
||||
model_paths.folder_name_base_path_subdir = potential_subdir
|
||||
model_paths.folder_names_are_relative_directory_paths_too = True
|
||||
if folder_name_not_already_in_contents:
|
||||
do_add(model_paths.folder_names, index, potential_folder_name)
|
||||
else:
|
||||
# if the folder name doesn't match the key, check if we have ever seen a folder name that matches the key:
|
||||
if model_paths.folder_names_are_relative_directory_paths_too:
|
||||
if potential_subdir == model_paths.folder_name_base_path_subdir:
|
||||
# then we want to add this to the folder name, because it's probably compatibility
|
||||
do_add(model_paths.folder_names, index, potential_folder_name)
|
||||
else:
|
||||
# not this case
|
||||
model_paths.folder_names_are_relative_directory_paths_too = False
|
||||
|
||||
if not model_paths.folder_names_are_relative_directory_paths_too:
|
||||
model_paths.additional_relative_directory_paths.add(relative_to_basepath)
|
||||
for resolve_folder_name in model_paths.folder_names:
|
||||
model_paths.additional_relative_directory_paths.add(model_paths.folder_name_base_path_subdir / resolve_folder_name)
|
||||
|
||||
# since this was an absolute path that was a subdirectory of one of the base paths,
|
||||
# we are done
|
||||
break
|
||||
except ValueError:
|
||||
# this is not a subpath of the base path
|
||||
pass
|
||||
|
||||
# if we got this far, none of the absolute paths were subdirectories of any base paths
|
||||
# add it to our absolute paths
|
||||
model_paths.additional_absolute_directory_paths.add(path)
|
||||
else:
|
||||
# since this is a relative path, peacefully add it to model_paths
|
||||
potential_folder_name = path.stem
|
||||
|
||||
try:
|
||||
relative_to_current_subdir = path.relative_to(model_paths.folder_name_base_path_subdir)
|
||||
|
||||
# if relative to the current subdir, we observe only one part, we're good to go
|
||||
if len(relative_to_current_subdir.parts) == 1:
|
||||
if potential_folder_name == key:
|
||||
model_paths.folder_names_are_relative_directory_paths_too = True
|
||||
else:
|
||||
# if there already exists a folder_name by this name, do not add it, and switch to all relative paths
|
||||
if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents):
|
||||
model_paths.folder_names_are_relative_directory_paths_too = False
|
||||
model_paths.additional_relative_directory_paths.add(path)
|
||||
model_paths.folder_name_base_path_subdir = construct_path()
|
||||
else:
|
||||
do_add(model_paths.folder_names, index, potential_folder_name)
|
||||
except ValueError:
|
||||
# this means that the relative directory didn't contain the subdir so far
|
||||
# something_not_models/key
|
||||
if potential_folder_name == key:
|
||||
model_paths.folder_name_base_path_subdir = path.parent
|
||||
model_paths.folder_names_are_relative_directory_paths_too = True
|
||||
else:
|
||||
if any(candidate.has_folder_name(potential_folder_name) for candidate in self.contents):
|
||||
model_paths.folder_names_are_relative_directory_paths_too = False
|
||||
model_paths.additional_relative_directory_paths.add(path)
|
||||
model_paths.folder_name_base_path_subdir = construct_path()
|
||||
else:
|
||||
do_add(model_paths.folder_names, index, potential_folder_name)
|
||||
return model_paths
|
||||
|
||||
def __setitem__(self, key: str, value: tuple | FolderPathsTuple | AbstractPaths):
|
||||
# remove all existing paths for this key
|
||||
self.__delitem__(key)
|
||||
|
||||
if isinstance(value, AbstractPaths):
|
||||
self.contents.append(value)
|
||||
elif isinstance(value, (tuple, FolderPathsTuple)):
|
||||
paths, supported_extensions = value
|
||||
value = FolderPathsTuple(key, paths, supported_extensions)
|
||||
self.contents[key] = value
|
||||
# typical cases:
|
||||
# key="checkpoints", paths="C:/base_path/models/checkpoints"
|
||||
# key="unets", paths="C:/base_path/models/unets", "C:/base_path/models/diffusion_models"
|
||||
# ^ in this case, we will want folder_names to be both unets and diffusion_models
|
||||
# key="custom_loader", paths="C:/base_path/models/checkpoints"
|
||||
|
||||
# if the paths are subdirectories of any basepath, use relative paths
|
||||
paths: list[Path] = list(map(Path, paths))
|
||||
self.contents.append(self._modify_model_paths(key, paths, supported_extensions))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.contents)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.contents)
|
||||
def __iter__(self) -> typing.Generator[str]:
|
||||
for model_paths in self.contents:
|
||||
for folder_name in model_paths.folder_names:
|
||||
yield folder_name
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.contents[key]
|
||||
to_remove: list[AbstractPaths] = []
|
||||
for model_paths in self.contents:
|
||||
if model_paths.has_folder_name(key):
|
||||
to_remove.append(model_paths)
|
||||
|
||||
def __contains__(self, item):
|
||||
return item in self.contents
|
||||
for model_paths in to_remove:
|
||||
self.contents.remove(model_paths)
|
||||
|
||||
def __contains__(self, item: str):
|
||||
return any(candidate.has_folder_name(item) for candidate in self.contents)
|
||||
|
||||
def items(self):
|
||||
return self.contents.items()
|
||||
items_view = {
|
||||
folder_name: self[folder_name] for folder_name in self.keys()
|
||||
}
|
||||
return items_view.items()
|
||||
|
||||
def values(self):
|
||||
return self.contents.values()
|
||||
return [self[folder_name] for folder_name in self.keys()]
|
||||
|
||||
def keys(self):
|
||||
return self.contents.keys()
|
||||
return [x for x in self]
|
||||
|
||||
def get(self, key, __default=None):
|
||||
return self.contents.get(key, __default)
|
||||
for candidate in self.contents:
|
||||
if candidate.has_folder_name(key):
|
||||
return FolderPathsTuple(key, parent=weakref.ref(self))
|
||||
if __default is not None:
|
||||
raise ValueError("get with default is not supported")
|
||||
|
||||
|
||||
class SaveImagePathResponse(NamedTuple):
|
||||
class SaveImagePathTuple(NamedTuple):
|
||||
full_output_folder: str
|
||||
filename: str
|
||||
counter: int
|
||||
|
||||
@ -1,22 +1,44 @@
|
||||
import sys
|
||||
from functools import wraps
|
||||
|
||||
def create_module_properties():
|
||||
properties = {}
|
||||
patched_modules = set()
|
||||
|
||||
def module_property(func):
|
||||
"""Decorator to turn module functions into properties.
|
||||
Function names must be prefixed with an underscore."""
|
||||
module = sys.modules[func.__module__]
|
||||
def patch_module(module):
|
||||
if module in patched_modules:
|
||||
return
|
||||
|
||||
def base_getattr(name):
|
||||
raise AttributeError(
|
||||
f"module '{module.__name__}' has no attribute '{name}'")
|
||||
def base_getattr(name):
|
||||
raise AttributeError(f"module '{module.__name__}' has no attribute '{name}'")
|
||||
|
||||
old_getattr = getattr(module, '__getattr__', base_getattr)
|
||||
old_getattr = getattr(module, '__getattr__', base_getattr)
|
||||
|
||||
def new_getattr(name):
|
||||
if f'_{name}' == func.__name__:
|
||||
return func()
|
||||
else:
|
||||
return old_getattr(name)
|
||||
def new_getattr(name):
|
||||
if name in properties:
|
||||
return properties[name]()
|
||||
else:
|
||||
return old_getattr(name)
|
||||
|
||||
module.__getattr__ = new_getattr
|
||||
return func
|
||||
module.__getattr__ = new_getattr
|
||||
patched_modules.add(module)
|
||||
|
||||
class ModuleProperties:
|
||||
@staticmethod
|
||||
def getter(func):
|
||||
@wraps(func)
|
||||
def wrapper():
|
||||
return func()
|
||||
|
||||
name = func.__name__
|
||||
if name.startswith('_'):
|
||||
properties[name[1:]] = wrapper
|
||||
else:
|
||||
raise ValueError("Property function names must start with an underscore")
|
||||
|
||||
module = sys.modules[func.__module__]
|
||||
patch_module(module)
|
||||
|
||||
return wrapper
|
||||
|
||||
return ModuleProperties()
|
||||
10
comfy/component_model/outputs_types.py
Normal file
10
comfy/component_model/outputs_types.py
Normal file
@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
# for nodes that return:
|
||||
# { "ui" : { "some_field": Any, "other": Any }}
|
||||
# the outputs dict will be
|
||||
# (the node id)
|
||||
# { "1": { "some_field": Any, "other": Any }}
|
||||
OutputsDict = dict[str, dict[str, typing.Any]]
|
||||
13
comfy/component_model/platform_path.py
Normal file
13
comfy/component_model/platform_path.py
Normal file
@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import PurePosixPath, Path, PosixPath
|
||||
|
||||
|
||||
def construct_path(*args) -> PurePosixPath | Path:
|
||||
if len(args) > 0 and args[0] is not None and isinstance(args[0], str) and args[0].startswith("/"):
|
||||
try:
|
||||
return PosixPath(*args)
|
||||
except Exception:
|
||||
return PurePosixPath(*args)
|
||||
else:
|
||||
return Path(*args)
|
||||
@ -8,13 +8,15 @@ from typing import Tuple
|
||||
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from .outputs_types import OutputsDict
|
||||
|
||||
QueueTuple = Tuple[float, str, dict, dict, list]
|
||||
MAXIMUM_HISTORY_SIZE = 10000
|
||||
|
||||
|
||||
class TaskInvocation(NamedTuple):
|
||||
item_id: int | str
|
||||
outputs: dict
|
||||
outputs: OutputsDict
|
||||
status: Optional[ExecutionStatus]
|
||||
|
||||
|
||||
|
||||
@ -10,11 +10,12 @@ from aio_pika.patterns import JsonRPC
|
||||
from aiohttp import web
|
||||
from aiormq import AMQPConnectionError
|
||||
|
||||
from .executors import ContextVarExecutor
|
||||
from .distributed_progress import DistributedExecutorToClientProgress
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from .process_pool_executor import ProcessPoolExecutor
|
||||
from ..client.embedded_comfy_client import EmbeddedComfyClient
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.executor_types import Executor
|
||||
from ..component_model.queue_types import ExecutionStatus
|
||||
|
||||
|
||||
@ -28,7 +29,7 @@ class DistributedPromptWorker:
|
||||
queue_name: str = "comfyui",
|
||||
health_check_port: int = 9090,
|
||||
loop: Optional[AbstractEventLoop] = None,
|
||||
executor: Optional[Executor] = None):
|
||||
executor: Optional[ContextVarExecutor | ProcessPoolExecutor] = None):
|
||||
self._rpc = None
|
||||
self._channel = None
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
24
comfy/distributed/executors.py
Normal file
24
comfy/distributed/executors.py
Normal file
@ -0,0 +1,24 @@
|
||||
import concurrent
|
||||
import contextvars
|
||||
import typing
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
__version__ = '0.0.1'
|
||||
|
||||
from .process_pool_executor import ProcessPoolExecutor
|
||||
|
||||
|
||||
class ContextVarExecutor(ThreadPoolExecutor):
|
||||
|
||||
def submit(self, fn: typing.Callable, *args, **kwargs) -> Future:
|
||||
ctx = contextvars.copy_context() # type: contextvars.Context
|
||||
|
||||
return super().submit(partial(ctx.run, partial(fn, *args, **kwargs)))
|
||||
|
||||
|
||||
class ContextVarProcessPoolExecutor(ProcessPoolExecutor):
|
||||
|
||||
def submit(self, fn, /, *args, **kwargs) -> concurrent.futures.Future:
|
||||
# TODO: serialize the "comfyui_execution_context"
|
||||
pass
|
||||
@ -6,6 +6,7 @@ from dataclasses import dataclass, replace
|
||||
from typing import Optional, Final
|
||||
|
||||
from .component_model.executor_types import ExecutorToClientProgress
|
||||
from .component_model.folder_path_types import FolderNames
|
||||
from .distributed.server_stub import ServerStub
|
||||
|
||||
_current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
|
||||
@ -14,23 +15,21 @@ _current_context: Final[ContextVar] = ContextVar("comfyui_execution_context")
|
||||
@dataclass(frozen=True)
|
||||
class ExecutionContext:
|
||||
server: ExecutorToClientProgress
|
||||
folder_names_and_paths: FolderNames
|
||||
node_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
inference_mode: bool = True
|
||||
|
||||
|
||||
_empty_execution_context: Final[ExecutionContext] = ExecutionContext(server=ServerStub())
|
||||
_current_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames()))
|
||||
|
||||
|
||||
def current_execution_context() -> ExecutionContext:
|
||||
try:
|
||||
return _current_context.get()
|
||||
except LookupError:
|
||||
return _empty_execution_context
|
||||
return _current_context.get()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def new_execution_context(ctx: ExecutionContext):
|
||||
def _new_execution_context(ctx: ExecutionContext):
|
||||
token = _current_context.set(ctx)
|
||||
try:
|
||||
yield ctx
|
||||
@ -39,8 +38,24 @@ def new_execution_context(ctx: ExecutionContext):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_execute_node(node_id: str, prompt_id: str):
|
||||
def context_folder_names_and_paths(folder_names_and_paths: FolderNames):
|
||||
current_ctx = current_execution_context()
|
||||
new_ctx = replace(current_ctx, node_id=node_id, task_id=prompt_id)
|
||||
with new_execution_context(new_ctx):
|
||||
new_ctx = replace(current_ctx, folder_names_and_paths=folder_names_and_paths)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_execute_prompt(server: ExecutorToClientProgress, prompt_id: str, inference_mode: bool = True):
|
||||
current_ctx = current_execution_context()
|
||||
new_ctx = replace(current_ctx, server=server, task_id=prompt_id, inference_mode=inference_mode)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
|
||||
@contextmanager
|
||||
def context_execute_node(node_id: str):
|
||||
current_ctx = current_execution_context()
|
||||
new_ctx = replace(current_ctx, node_id=node_id)
|
||||
with _new_execution_context(new_ctx):
|
||||
yield new_ctx
|
||||
|
||||
@ -198,10 +198,6 @@ Visit the repository, accept the terms, and then do one of the following:
|
||||
- Login to Hugging Face in your terminal using `huggingface-cli login`
|
||||
""")
|
||||
raise exc_info
|
||||
finally:
|
||||
# a path was found for any reason, so we should invalidate the cache
|
||||
if path is not None:
|
||||
folder_paths.invalidate_cache(folder_name)
|
||||
if path is None:
|
||||
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
|
||||
return path
|
||||
@ -504,7 +500,6 @@ def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]
|
||||
|
||||
pre_existing = frozenset(known_models)
|
||||
known_models.extend([model for model in models if model not in pre_existing])
|
||||
folder_paths.invalidate_cache(folder_name)
|
||||
return known_models
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ import torch
|
||||
|
||||
import os
|
||||
import json
|
||||
import hashlib
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
@ -1654,7 +1653,7 @@ class LoadImageMask:
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True}),
|
||||
{"image": (natsorted(files), {"image_upload": True}),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
|
||||
|
||||
@ -29,9 +29,9 @@ import sys
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from pickle import UnpicklingError
|
||||
from typing import Optional, Any
|
||||
|
||||
import accelerate
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
@ -65,7 +65,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
|
||||
if ckpt is None:
|
||||
raise FileNotFoundError("the checkpoint was not found")
|
||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||
sd = safetensors.torch.load_file(ckpt, device=device.type)
|
||||
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
|
||||
elif ckpt.lower().endswith("index.json"):
|
||||
# from accelerate
|
||||
index_filename = ckpt
|
||||
@ -81,20 +81,29 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
|
||||
for checkpoint_file in checkpoint_files:
|
||||
sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
|
||||
else:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
safe_load = False
|
||||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
logging.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
sd = pl_sd
|
||||
try:
|
||||
if safe_load:
|
||||
if not 'weights_only' in torch.load.__code__.co_varnames:
|
||||
logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
|
||||
safe_load = False
|
||||
if safe_load:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
|
||||
else:
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
|
||||
if "global_step" in pl_sd:
|
||||
logging.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
sd = pl_sd
|
||||
except UnpicklingError as exc_info:
|
||||
try:
|
||||
# wrong extension is most likely, try to load as safetensors anyway
|
||||
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
|
||||
return sd
|
||||
except Exception:
|
||||
exc_info.add_note(f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again")
|
||||
raise exc_info
|
||||
return sd
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ from transformers.models.nllb.tokenization_nllb import \
|
||||
FAIRSEQ_LANGUAGE_CODES as tokenization_nllb_FAIRSEQ_LANGUAGE_CODES
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.component_model.folder_path_types import SaveImagePathResponse
|
||||
from comfy.component_model.folder_path_types import SaveImagePathTuple
|
||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||
from comfy.language.language_types import GENERATION_KWARGS_TYPE, GENERATION_KWARGS_TYPE_NAME, TOKENS_TYPE, \
|
||||
TOKENS_TYPE_NAME, LanguageModel
|
||||
@ -397,7 +397,7 @@ class SaveString(CustomNode):
|
||||
OUTPUT_NODE = True
|
||||
RETURN_TYPES = ()
|
||||
|
||||
def get_save_path(self, filename_prefix) -> SaveImagePathResponse:
|
||||
def get_save_path(self, filename_prefix) -> SaveImagePathTuple:
|
||||
return folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory(), 0, 0)
|
||||
|
||||
def execute(self, value: str | list[str], filename_prefix: str, extension: str = ".json"):
|
||||
|
||||
@ -1,21 +0,0 @@
|
||||
import torch
|
||||
|
||||
class TorchCompileModel:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model):
|
||||
m = model.clone()
|
||||
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model")))
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"TorchCompileModel": TorchCompileModel,
|
||||
}
|
||||
10
main.py
10
main.py
@ -2,13 +2,15 @@ import asyncio
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
if __name__ == "__main__":
|
||||
from comfy.cmd.folder_paths_pre import set_base_path
|
||||
from comfy.component_model.folder_path_types import FolderNames
|
||||
|
||||
if __name__ == "__main__":
|
||||
warnings.warn("main.py is deprecated. Start comfyui by installing the package through the instructions in the README, not by cloning the repository.", DeprecationWarning)
|
||||
this_file_parent_dir = Path(__file__).parent
|
||||
set_base_path(str(this_file_parent_dir))
|
||||
|
||||
from comfy.cmd.main import main
|
||||
from comfy.cmd.folder_paths import folder_names_and_paths # type: FolderNames
|
||||
fn: FolderNames = folder_names_and_paths
|
||||
fn.base_paths.clear()
|
||||
fn.base_paths.append(this_file_parent_dir)
|
||||
|
||||
asyncio.run(main(from_script_dir=this_file_parent_dir))
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable
|
||||
|
||||
import jwt
|
||||
@ -15,10 +17,12 @@ from comfy.component_model.executor_types import Executor
|
||||
from comfy.component_model.make_mutable import make_mutable
|
||||
from comfy.component_model.queue_types import QueueItem, QueueTuple, TaskInvocation, NamedQueueTuple, ExecutionStatus
|
||||
from comfy.distributed.distributed_prompt_worker import DistributedPromptWorker
|
||||
from comfy.distributed.executors import ContextVarExecutor
|
||||
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
|
||||
from comfy.distributed.server_stub import ServerStub
|
||||
|
||||
|
||||
|
||||
def create_test_prompt() -> QueueItem:
|
||||
from comfy.cmd.execution import validate_prompt
|
||||
|
||||
@ -37,8 +41,11 @@ async def test_sign_jwt_auth_none():
|
||||
assert user_token["sub"] == client_id
|
||||
|
||||
|
||||
_executor_factories: tuple[Executor] = (ContextVarExecutor,)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
|
||||
@pytest.mark.parametrize("executor_factory", _executor_factories)
|
||||
async def test_basic_queue_worker(executor_factory: Callable[..., Executor]) -> None:
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
params = rabbitmq.get_connection_params()
|
||||
@ -72,7 +79,7 @@ async def test_distributed_prompt_queues_same_process():
|
||||
frontend.put(test_prompt)
|
||||
|
||||
# start a worker thread
|
||||
thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||
thread_pool = ContextVarExecutor(max_workers=1)
|
||||
|
||||
async def in_thread():
|
||||
incoming, incoming_prompt_id = worker.get()
|
||||
@ -127,7 +134,7 @@ async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("executor_factory", (ThreadPoolExecutor, ProcessPoolExecutor,))
|
||||
@pytest.mark.parametrize("executor_factory", _executor_factories)
|
||||
async def test_basic_queue_worker_with_health_check(executor_factory):
|
||||
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||
params = rabbitmq.get_connection_params()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Dict, Optional
|
||||
@ -73,7 +74,8 @@ class ComfyClient:
|
||||
prompt_id = str(uuid.uuid4())
|
||||
try:
|
||||
outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
|
||||
except (RuntimeError, DependencyCycleError):
|
||||
except (RuntimeError, DependencyCycleError) as exc_info:
|
||||
logging.warning("error when queueing prompt", exc_info=exc_info)
|
||||
outputs = {}
|
||||
result = RunResult(prompt_id=prompt_id)
|
||||
result.outputs = outputs
|
||||
|
||||
@ -2,11 +2,13 @@
|
||||
# TODO(yoland): clean up this after I get back down
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.component_model.folder_path_types import FolderNames, ModelPaths
|
||||
from comfy.execution_context import context_folder_names_and_paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -40,16 +42,6 @@ def test_add_model_folder_path():
|
||||
assert "/test/path" in folder_paths.get_folder_paths("test_folder")
|
||||
|
||||
|
||||
def test_recursive_search(temp_dir):
|
||||
os.makedirs(os.path.join(temp_dir, "subdir"))
|
||||
open(os.path.join(temp_dir, "file1.txt"), "w").close()
|
||||
open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close()
|
||||
|
||||
files, dirs = folder_paths.recursive_search(temp_dir)
|
||||
assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")}
|
||||
assert len(dirs) == 2 # temp_dir and subdir
|
||||
|
||||
|
||||
def test_filter_files_extensions():
|
||||
files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"]
|
||||
assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"]
|
||||
@ -57,16 +49,24 @@ def test_filter_files_extensions():
|
||||
assert folder_paths.filter_files_extensions(files, []) == files
|
||||
|
||||
|
||||
@patch("folder_paths.recursive_search")
|
||||
@patch("folder_paths.folder_names_and_paths")
|
||||
def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search):
|
||||
mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"})
|
||||
mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {})
|
||||
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||
def test_get_filename_list(temp_dir):
|
||||
base_path = Path(temp_dir)
|
||||
fn = FolderNames(base_paths=[base_path])
|
||||
rel_path = Path("test/path")
|
||||
fn.add(ModelPaths(["test_folder"], additional_relative_directory_paths={rel_path}, supported_extensions={".txt"}))
|
||||
dir_path = base_path / rel_path
|
||||
Path.mkdir(dir_path, parents=True, exist_ok=True)
|
||||
files = ["file1.txt", "file2.jpg"]
|
||||
|
||||
for file in files:
|
||||
Path.touch(dir_path / file, exist_ok=True)
|
||||
|
||||
with context_folder_names_and_paths(fn):
|
||||
assert folder_paths.get_filename_list("test_folder") == ["file1.txt"]
|
||||
|
||||
|
||||
def test_get_save_image_path(temp_dir):
|
||||
with patch("folder_paths.output_directory", temp_dir):
|
||||
with context_folder_names_and_paths(FolderNames(base_paths=[Path(temp_dir)])):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100)
|
||||
assert os.path.samefile(full_output_folder, temp_dir)
|
||||
assert filename == "test"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user