diff --git a/folder_paths.py b/folder_paths.py index 92e8df3cf..eb9de1408 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -121,6 +121,11 @@ def set_temp_directory(temp_dir: str) -> None: def set_input_directory(input_dir: str) -> None: global input_directory + try: + os.makedirs(input_dir, exist_ok=True) + except OSError: + logging.exception("Failed to create input directory: %s", input_dir) + raise input_directory = input_dir def get_output_directory() -> str: diff --git a/tests/test_folder_paths.py b/tests/test_folder_paths.py new file mode 100644 index 000000000..163b58a21 --- /dev/null +++ b/tests/test_folder_paths.py @@ -0,0 +1,30 @@ +import pytest + +import folder_paths + + +def test_set_input_directory_creates_missing_directory(tmp_path): + original_input_directory = folder_paths.get_input_directory() + custom_input_directory = tmp_path / "custom-input" + + try: + folder_paths.set_input_directory(str(custom_input_directory)) + + assert custom_input_directory.is_dir() + finally: + folder_paths.set_input_directory(original_input_directory) + + +def test_set_input_directory_keeps_original_when_creation_fails(tmp_path, monkeypatch): + original_input_directory = folder_paths.get_input_directory() + custom_input_directory = tmp_path / "custom-input" + + def fail_to_create_directory(path, exist_ok=False): + raise OSError("create failed") + + monkeypatch.setattr(folder_paths.os, "makedirs", fail_to_create_directory) + + with pytest.raises(OSError, match="create failed"): + folder_paths.set_input_directory(str(custom_input_directory)) + + assert folder_paths.get_input_directory() == original_input_directory