Fix #25 custom nodes which have input paths set at import time will now correctly see a models directory (or similar) that respects the configuration intended by the user

This commit is contained in:
doctorpangloss 2024-11-01 13:40:03 -07:00
parent a13b9d3b54
commit 021d0d4f57
10 changed files with 117 additions and 36 deletions

3
.gitignore vendored
View File

@ -1,3 +1,6 @@
main.exp
main.lib
main.obj
.DS_Store .DS_Store
/[Oo]utput/ /[Oo]utput/
/[Ii]nput/ /[Ii]nput/

View File

@ -124,7 +124,7 @@ class Configuration(dict):
self._observers: List[ConfigObserver] = [] self._observers: List[ConfigObserver] = []
self.config_files = [] self.config_files = []
self.cwd: Optional[str] = None self.cwd: Optional[str] = None
self.base_paths: list[Path] = [] self.base_paths: list[str] = []
self.listen: str = "127.0.0.1" self.listen: str = "127.0.0.1"
self.port: int = 8188 self.port: int = 8188
self.enable_cors_header: Optional[str] = None self.enable_cors_header: Optional[str] = None

View File

@ -37,7 +37,7 @@ def _execute_prompt(
configuration: Configuration | None) -> dict: configuration: Configuration | None) -> dict:
execution_context = current_execution_context() execution_context = current_execution_context()
if len(execution_context.folder_names_and_paths) == 0 or configuration is not None: if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
init_default_paths(execution_context.folder_names_and_paths, configuration) init_default_paths(execution_context.folder_names_and_paths, configuration, replace_existing=True)
span_context: Context = propagate.extract(span_context) span_context: Context = propagate.extract(span_context)
token = attach(span_context) token = attach(span_context)
try: try:
@ -54,6 +54,8 @@ def __execute_prompt(
progress_handler: ExecutorToClientProgress | None, progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None) -> dict: configuration: Configuration | None) -> dict:
from .. import options from .. import options
from ..cmd.execution import PromptExecutor
progress_handler = progress_handler or ServerStub() progress_handler = progress_handler or ServerStub()
try: try:
@ -66,8 +68,7 @@ def __execute_prompt(
args.clear() args.clear()
args.update(configuration) args.update(configuration)
from ..cmd.execution import PromptExecutor with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0) prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
prompt_executor.raise_exceptions = True prompt_executor.raise_exceptions = True
_prompt_executor.executor = prompt_executor _prompt_executor.executor = prompt_executor

View File

@ -5,6 +5,7 @@ import mimetypes
import os import os
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from functools import reduce
from pathlib import Path, PurePosixPath from pathlib import Path, PurePosixPath
from typing import Optional, List, Literal from typing import Optional, List, Literal
@ -51,46 +52,60 @@ def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path:
return Path(path).resolve() return Path(path).resolve()
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False): def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False, replace_existing=True):
""" """
Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory. Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory.
:param folder_names_and_paths: the object to populate with paths :param folder_names_and_paths: the object to populate with paths
:param configuration: a configuration whose base_paths and other path settings will be used to set the values on this object :param configuration: a configuration whose base_paths and other path settings will be used to set the values on this object
:param create_all_directories: create all the possible directories by calling create_directories() after the object is populated :param create_all_directories: create all the possible directories by calling create_directories() after the object is populated
:param replace_existing: when true, removes existing model paths objects for the built-in folder names; and, replaces the base paths
:return: :return:
""" """
from ..cmd.main_pre import args from ..cmd.main_pre import args
configuration = configuration or args configuration = configuration or args
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
base_paths = [path for path in base_paths if path is not None] base_paths = [Path(path) for path in base_paths if path is not None]
if len(base_paths) == 0: if len(base_paths) == 0:
base_paths = [Path(os.getcwd())] base_paths = [Path(os.getcwd())]
base_paths = reduce(lambda uniq_list, item: uniq_list.append(item) or uniq_list if item not in uniq_list else uniq_list, base_paths, [])
if replace_existing:
folder_names_and_paths.base_paths.clear()
for base_path in base_paths: for base_path in base_paths:
folder_names_and_paths.add_base_path(base_path) folder_names_and_paths.add_base_path(base_path)
folder_names_and_paths.add(ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}))
folder_names_and_paths.add(ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["clip"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True))
folder_names_and_paths.add(ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["diffusers"], supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(folder_names=["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True))
folder_names_and_paths.add(ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["classifiers"], supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["huggingface"], supported_extensions=set()))
hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set()) hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set())
# TODO: explore if there is a better way to do this # TODO: explore if there is a better way to do this
if "HF_HUB_CACHE" in os.environ: if "HF_HUB_CACHE" in os.environ:
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE")) hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
folder_names_and_paths.add(hf_cache_paths)
model_paths_to_add = [
ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}),
ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["clip"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["diffusers"], supported_extensions=set()),
ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()),
ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["classifiers"], supported_extensions=set()),
ModelPaths(["huggingface"], supported_extensions=set()),
hf_cache_paths,
]
for model_paths in model_paths_to_add:
if replace_existing:
for folder_name in model_paths.folder_names:
del folder_names_and_paths[folder_name]
folder_names_and_paths.add(model_paths)
if create_all_directories: if create_all_directories:
create_directories(folder_names_and_paths) create_directories(folder_names_and_paths)

View File

@ -17,11 +17,7 @@ supported_pt_extensions: set[str]
# Functions # Functions
def init_default_paths( def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories: bool = ..., replace_existing: bool = ...): ...
folder_names_and_paths: FolderNames,
configuration: Optional[Configuration] = ...,
create_all_directories: bool = ...
) -> None: ...
def map_legacy(folder_name: str) -> str: ... def map_legacy(folder_name: str) -> str: ...

View File

@ -6,7 +6,7 @@ import os
import typing import typing
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path, PurePosixPath from pathlib import Path
from typing import Any, NamedTuple, Optional, Iterable from typing import Any, NamedTuple, Optional, Iterable
from .platform_path import construct_path from .platform_path import construct_path
@ -452,9 +452,16 @@ class FolderNames:
def __delitem__(self, key): def __delitem__(self, key):
to_remove: list[AbstractPaths] = [] to_remove: list[AbstractPaths] = []
if isinstance(key, str):
folder_names = [key]
else:
iter(key)
folder_names = key
for model_paths in self.contents: for model_paths in self.contents:
if model_paths.has_folder_name(key): for folder_name in folder_names:
to_remove.append(model_paths) if model_paths.has_folder_name(folder_name):
to_remove.append(model_paths)
for model_paths in to_remove: for model_paths in to_remove:
self.contents.remove(model_paths) self.contents.remove(model_paths)

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
logging.basicConfig(level=logging.ERROR) logging.basicConfig(level=logging.ERROR)
import uuid import uuid
@ -22,7 +23,6 @@ from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub from comfy.distributed.server_stub import ServerStub
def create_test_prompt() -> QueueItem: def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt from comfy.cmd.execution import validate_prompt
@ -41,7 +41,7 @@ async def test_sign_jwt_auth_none():
assert user_token["sub"] == client_id assert user_token["sub"] == client_id
_executor_factories: tuple[Executor] = (ContextVarExecutor,) _executor_factories: tuple[Executor] = (ContextVarExecutor, ProcessPoolExecutor)
@pytest.mark.asyncio @pytest.mark.asyncio

0
tests/issues/__init__.py Normal file
View File

View File

@ -0,0 +1,36 @@
import os.path
import tempfile
from importlib.resources import files
import pytest
from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
_TEST_WORKFLOW = {
"0": {
"inputs": {},
"class_type": "TestPath",
"_meta": {
"title": ""
}
}
}
@pytest.mark.asyncio
async def test_respect_cwd_param():
with tempfile.TemporaryDirectory() as tmp_dir:
cwd = str(tmp_dir)
config = Configuration(cwd=cwd)
# for finding the custom nodes
config.base_paths = [files(__package__)]
from comfy.cmd.folder_paths import models_dir
assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration"
client = EmbeddedComfyClient(config)
prompt = Prompt.validate(_TEST_WORKFLOW)
outputs = await client.queue_prompt_api(prompt)
path_as_imported = outputs.outputs["0"]["path"][0]
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"

View File

@ -0,0 +1,23 @@
import os
from comfy.cmd import folder_paths
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables
TEST_PATH = os.path.join(folder_paths.models_dir, "test", "path")
class TestPath(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {"required": {}}
FUNCTION = "execute"
OUTPUT_NODE = True
def execute(self) -> FunctionReturnsUIVariables:
return {"ui": {"path": [TEST_PATH]}}
NODE_CLASS_MAPPINGS = {
"TestPath": TestPath
}