diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 846d04dbe..d4fd55c87 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -30,6 +30,9 @@ a111: # clip_vision: models/clip_vision/ # configs: models/configs/ # controlnet: models/controlnet/ +# diffusion_models: | +# models/diffusion_models +# models/unet # embeddings: models/embeddings/ # loras: models/loras/ # upscale_models: models/upscale_models/ diff --git a/main.py b/main.py index d791a169c..748846b4a 100644 --- a/main.py +++ b/main.py @@ -63,6 +63,7 @@ import threading import gc import logging +import utils.extra_config if os.name == "nt": logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -85,7 +86,6 @@ if args.windows_standalone_build: pass import comfy.utils -import yaml import execution import server @@ -180,27 +180,6 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) -def load_extra_path_config(yaml_path): - with open(yaml_path, 'r') as stream: - config = yaml.safe_load(stream) - for c in config: - conf = config[c] - if conf is None: - continue - base_path = None - if "base_path" in conf: - base_path = conf.pop("base_path") - for x in conf: - for y in conf[x].split("\n"): - if len(y) == 0: - continue - full_path = y - if base_path is not None: - full_path = os.path.join(base_path, full_path) - logging.info("Adding extra search path {} {}".format(x, full_path)) - folder_paths.add_model_folder_path(x, full_path) - - if __name__ == "__main__": if args.temp_directory: temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") @@ -222,11 +201,11 @@ if __name__ == "__main__": extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") if os.path.isfile(extra_model_paths_config_path): - load_extra_path_config(extra_model_paths_config_path) + utils.extra_config.load_extra_path_config(extra_model_paths_config_path) if args.extra_model_paths_config: for config_path in itertools.chain(*args.extra_model_paths_config): - load_extra_path_config(config_path) + utils.extra_config.load_extra_path_config(config_path) nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes) diff --git a/tests-unit/comfy_test/folder_path_test.py b/tests-unit/comfy_test/folder_path_test.py new file mode 100644 index 000000000..0bbec593b --- /dev/null +++ b/tests-unit/comfy_test/folder_path_test.py @@ -0,0 +1,66 @@ +### 🗻 This file is created through the spirit of Mount Fuji at its peak +# TODO(yoland): clean up this after I get back down +import pytest +import os +import tempfile +from unittest.mock import patch + +import folder_paths + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + + +def test_get_directory_by_type(): + test_dir = "/test/dir" + folder_paths.set_output_directory(test_dir) + assert folder_paths.get_directory_by_type("output") == test_dir + assert folder_paths.get_directory_by_type("invalid") is None + +def test_annotated_filepath(): + assert folder_paths.annotated_filepath("test.txt") == ("test.txt", None) + assert folder_paths.annotated_filepath("test.txt [output]") == ("test.txt", folder_paths.get_output_directory()) + assert folder_paths.annotated_filepath("test.txt [input]") == ("test.txt", folder_paths.get_input_directory()) + assert folder_paths.annotated_filepath("test.txt [temp]") == ("test.txt", folder_paths.get_temp_directory()) + +def test_get_annotated_filepath(): + default_dir = "/default/dir" + assert folder_paths.get_annotated_filepath("test.txt", default_dir) == os.path.join(default_dir, "test.txt") + assert folder_paths.get_annotated_filepath("test.txt [output]") == os.path.join(folder_paths.get_output_directory(), "test.txt") + +def test_add_model_folder_path(): + folder_paths.add_model_folder_path("test_folder", "/test/path") + assert "/test/path" in folder_paths.get_folder_paths("test_folder") + +def test_recursive_search(temp_dir): + os.makedirs(os.path.join(temp_dir, "subdir")) + open(os.path.join(temp_dir, "file1.txt"), "w").close() + open(os.path.join(temp_dir, "subdir", "file2.txt"), "w").close() + + files, dirs = folder_paths.recursive_search(temp_dir) + assert set(files) == {"file1.txt", os.path.join("subdir", "file2.txt")} + assert len(dirs) == 2 # temp_dir and subdir + +def test_filter_files_extensions(): + files = ["file1.txt", "file2.jpg", "file3.png", "file4.txt"] + assert folder_paths.filter_files_extensions(files, [".txt"]) == ["file1.txt", "file4.txt"] + assert folder_paths.filter_files_extensions(files, [".jpg", ".png"]) == ["file2.jpg", "file3.png"] + assert folder_paths.filter_files_extensions(files, []) == files + +@patch("folder_paths.recursive_search") +@patch("folder_paths.folder_names_and_paths") +def test_get_filename_list(mock_folder_names_and_paths, mock_recursive_search): + mock_folder_names_and_paths.__getitem__.return_value = (["/test/path"], {".txt"}) + mock_recursive_search.return_value = (["file1.txt", "file2.jpg"], {}) + assert folder_paths.get_filename_list("test_folder") == ["file1.txt"] + +def test_get_save_image_path(temp_dir): + with patch("folder_paths.output_directory", temp_dir): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path("test", temp_dir, 100, 100) + assert os.path.samefile(full_output_folder, temp_dir) + assert filename == "test" + assert counter == 1 + assert subfolder == "" + assert filename_prefix == "test" \ No newline at end of file diff --git a/tests-unit/utils/extra_config_test.py b/tests-unit/utils/extra_config_test.py new file mode 100644 index 000000000..f56dd3e2e --- /dev/null +++ b/tests-unit/utils/extra_config_test.py @@ -0,0 +1,69 @@ +import pytest +import yaml +import os +from unittest.mock import Mock, patch, mock_open + +from utils.extra_config import load_extra_path_config +import folder_paths + +@pytest.fixture +def mock_yaml_content(): + return { + 'test_config': { + 'base_path': '~/App/', + 'checkpoints': 'subfolder1', + } + } + +@pytest.fixture +def mock_expanded_home(): + return '/home/user' + +@pytest.fixture +def mock_add_model_folder_path(): + return Mock() + +@pytest.fixture +def mock_expanduser(mock_expanded_home): + def _expanduser(path): + if path.startswith('~/'): + return os.path.join(mock_expanded_home, path[2:]) + return path + return _expanduser + +@pytest.fixture +def mock_yaml_safe_load(mock_yaml_content): + return Mock(return_value=mock_yaml_content) + +@patch('builtins.open', new_callable=mock_open, read_data="dummy file content") +def test_load_extra_model_paths_expands_userpath( + mock_file, + monkeypatch, + mock_add_model_folder_path, + mock_expanduser, + mock_yaml_safe_load, + mock_expanded_home +): + # Attach mocks used by load_extra_path_config + monkeypatch.setattr(folder_paths, 'add_model_folder_path', mock_add_model_folder_path) + monkeypatch.setattr(os.path, 'expanduser', mock_expanduser) + monkeypatch.setattr(yaml, 'safe_load', mock_yaml_safe_load) + + dummy_yaml_file_name = 'dummy_path.yaml' + load_extra_path_config(dummy_yaml_file_name) + + expected_calls = [ + ('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1')), + ] + + assert mock_add_model_folder_path.call_count == len(expected_calls) + + # Check if add_model_folder_path was called with the correct arguments + for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls): + assert actual_call.args == expected_call + + # Check if yaml.safe_load was called + mock_yaml_safe_load.assert_called_once() + + # Check if open was called with the correct file path + mock_file.assert_called_once_with(dummy_yaml_file_name, 'r') diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/utils/extra_config.py b/utils/extra_config.py new file mode 100644 index 000000000..23c2d791c --- /dev/null +++ b/utils/extra_config.py @@ -0,0 +1,25 @@ +import os +import yaml +import folder_paths +import logging + +def load_extra_path_config(yaml_path): + with open(yaml_path, 'r') as stream: + config = yaml.safe_load(stream) + for c in config: + conf = config[c] + if conf is None: + continue + base_path = None + if "base_path" in conf: + base_path = conf.pop("base_path") + base_path = os.path.expanduser(base_path) + for x in conf: + for y in conf[x].split("\n"): + if len(y) == 0: + continue + full_path = y + if base_path is not None: + full_path = os.path.join(base_path, full_path) + logging.info("Adding extra search path {} {}".format(x, full_path)) + folder_paths.add_model_folder_path(x, full_path)