Fix #25 custom nodes which have input paths set at import time will now correctly see a models directory (or similar) that respects the configuration intended by the user

This commit is contained in:
doctorpangloss 2024-11-01 13:40:03 -07:00
parent a13b9d3b54
commit 021d0d4f57
10 changed files with 117 additions and 36 deletions

3
.gitignore vendored
View File

@ -1,3 +1,6 @@
main.exp
main.lib
main.obj
.DS_Store
/[Oo]utput/
/[Ii]nput/

View File

@ -124,7 +124,7 @@ class Configuration(dict):
self._observers: List[ConfigObserver] = []
self.config_files = []
self.cwd: Optional[str] = None
self.base_paths: list[Path] = []
self.base_paths: list[str] = []
self.listen: str = "127.0.0.1"
self.port: int = 8188
self.enable_cors_header: Optional[str] = None

View File

@ -37,7 +37,7 @@ def _execute_prompt(
configuration: Configuration | None) -> dict:
execution_context = current_execution_context()
if len(execution_context.folder_names_and_paths) == 0 or configuration is not None:
init_default_paths(execution_context.folder_names_and_paths, configuration)
init_default_paths(execution_context.folder_names_and_paths, configuration, replace_existing=True)
span_context: Context = propagate.extract(span_context)
token = attach(span_context)
try:
@ -54,6 +54,8 @@ def __execute_prompt(
progress_handler: ExecutorToClientProgress | None,
configuration: Configuration | None) -> dict:
from .. import options
from ..cmd.execution import PromptExecutor
progress_handler = progress_handler or ServerStub()
try:
@ -66,8 +68,7 @@ def __execute_prompt(
args.clear()
args.update(configuration)
from ..cmd.execution import PromptExecutor
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context) as span:
with tracer.start_as_current_span("Initialize Prompt Executor", context=span_context):
prompt_executor = PromptExecutor(progress_handler, lru_size=configuration.cache_lru if configuration is not None else 0)
prompt_executor.raise_exceptions = True
_prompt_executor.executor = prompt_executor

View File

@ -5,6 +5,7 @@ import mimetypes
import os
import time
from contextlib import nullcontext
from functools import reduce
from pathlib import Path, PurePosixPath
from typing import Optional, List, Literal
@ -51,46 +52,60 @@ def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path:
return Path(path).resolve()
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False):
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False, replace_existing=True):
"""
Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory.
:param folder_names_and_paths: the object to populate with paths
:param configuration: a configuration whose base_paths and other path settings will be used to set the values on this object
:param create_all_directories: create all the possible directories by calling create_directories() after the object is populated
:param replace_existing: when true, removes existing model paths objects for the built-in folder names; and, replaces the base paths
:return:
"""
from ..cmd.main_pre import args
configuration = configuration or args
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
base_paths = [path for path in base_paths if path is not None]
base_paths = [Path(path) for path in base_paths if path is not None]
if len(base_paths) == 0:
base_paths = [Path(os.getcwd())]
base_paths = reduce(lambda uniq_list, item: uniq_list.append(item) or uniq_list if item not in uniq_list else uniq_list, base_paths, [])
if replace_existing:
folder_names_and_paths.base_paths.clear()
for base_path in base_paths:
folder_names_and_paths.add_base_path(base_path)
folder_names_and_paths.add(ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}))
folder_names_and_paths.add(ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["clip"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True))
folder_names_and_paths.add(ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["diffusers"], supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(folder_names=["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True))
folder_names_and_paths.add(ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)))
folder_names_and_paths.add(ModelPaths(["classifiers"], supported_extensions=set()))
folder_names_and_paths.add(ModelPaths(["huggingface"], supported_extensions=set()))
hf_cache_paths = ModelPaths(["huggingface_cache"], supported_extensions=set())
# TODO: explore if there is a better way to do this
if "HF_HUB_CACHE" in os.environ:
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
folder_names_and_paths.add(hf_cache_paths)
model_paths_to_add = [
ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}),
ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["clip"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["style_models"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["embeddings"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["diffusers"], supported_extensions=set()),
ModelPaths(["vae_approx"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["controlnet", "t2i_adapter"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
ModelPaths(["gligen"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["upscale_models"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["custom_nodes"], folder_name_base_path_subdir=construct_path(""), supported_extensions=set()),
ModelPaths(["hypernetworks"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["classifiers"], supported_extensions=set()),
ModelPaths(["huggingface"], supported_extensions=set()),
hf_cache_paths,
]
for model_paths in model_paths_to_add:
if replace_existing:
for folder_name in model_paths.folder_names:
del folder_names_and_paths[folder_name]
folder_names_and_paths.add(model_paths)
if create_all_directories:
create_directories(folder_names_and_paths)

View File

@ -17,11 +17,7 @@ supported_pt_extensions: set[str]
# Functions
def init_default_paths(
folder_names_and_paths: FolderNames,
configuration: Optional[Configuration] = ...,
create_all_directories: bool = ...
) -> None: ...
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories: bool = ..., replace_existing: bool = ...): ...
def map_legacy(folder_name: str) -> str: ...

View File

@ -6,7 +6,7 @@ import os
import typing
import weakref
from abc import ABC, abstractmethod
from pathlib import Path, PurePosixPath
from pathlib import Path
from typing import Any, NamedTuple, Optional, Iterable
from .platform_path import construct_path
@ -452,9 +452,16 @@ class FolderNames:
def __delitem__(self, key):
to_remove: list[AbstractPaths] = []
if isinstance(key, str):
folder_names = [key]
else:
iter(key)
folder_names = key
for model_paths in self.contents:
if model_paths.has_folder_name(key):
to_remove.append(model_paths)
for folder_name in folder_names:
if model_paths.has_folder_name(folder_name):
to_remove.append(model_paths)
for model_paths in to_remove:
self.contents.remove(model_paths)

View File

@ -1,5 +1,6 @@
import asyncio
import logging
logging.basicConfig(level=logging.ERROR)
import uuid
@ -22,7 +23,6 @@ from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.distributed.server_stub import ServerStub
def create_test_prompt() -> QueueItem:
from comfy.cmd.execution import validate_prompt
@ -41,7 +41,7 @@ async def test_sign_jwt_auth_none():
assert user_token["sub"] == client_id
_executor_factories: tuple[Executor] = (ContextVarExecutor,)
_executor_factories: tuple[Executor] = (ContextVarExecutor, ProcessPoolExecutor)
@pytest.mark.asyncio

0
tests/issues/__init__.py Normal file
View File

View File

@ -0,0 +1,36 @@
import os.path
import tempfile
from importlib.resources import files
import pytest
from comfy.api.components.schema.prompt import Prompt
from comfy.cli_args_types import Configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
_TEST_WORKFLOW = {
"0": {
"inputs": {},
"class_type": "TestPath",
"_meta": {
"title": ""
}
}
}
@pytest.mark.asyncio
async def test_respect_cwd_param():
with tempfile.TemporaryDirectory() as tmp_dir:
cwd = str(tmp_dir)
config = Configuration(cwd=cwd)
# for finding the custom nodes
config.base_paths = [files(__package__)]
from comfy.cmd.folder_paths import models_dir
assert os.path.commonpath([os.getcwd(), models_dir]) == os.getcwd(), "at the time models_dir is accessed, the cwd should be the actual cwd, since there is no other configuration"
client = EmbeddedComfyClient(config)
prompt = Prompt.validate(_TEST_WORKFLOW)
outputs = await client.queue_prompt_api(prompt)
path_as_imported = outputs.outputs["0"]["path"][0]
assert os.path.commonpath([path_as_imported, cwd]) == cwd, "at the time the node is imported, the cwd should be the temporary directory"

View File

@ -0,0 +1,23 @@
import os
from comfy.cmd import folder_paths
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables
TEST_PATH = os.path.join(folder_paths.models_dir, "test", "path")
class TestPath(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {"required": {}}
FUNCTION = "execute"
OUTPUT_NODE = True
def execute(self) -> FunctionReturnsUIVariables:
return {"ui": {"path": [TEST_PATH]}}
NODE_CLASS_MAPPINGS = {
"TestPath": TestPath
}