Fix tests

This commit is contained in:
doctorpangloss 2024-10-29 12:29:02 -07:00
parent 4a13766d14
commit b42e59d602
3 changed files with 105 additions and 26 deletions

View File

@ -5,7 +5,7 @@ import mimetypes
import os
import time
from contextlib import nullcontext
from pathlib import Path
from pathlib import Path, PurePosixPath
from typing import Optional, List, Literal
from ..cli_args_types import Configuration
@ -19,6 +19,8 @@ from ..execution_context import current_execution_context
_module_properties = create_module_properties()
logger = logging.getLogger(__name__)
@_module_properties.getter
def _supported_pt_extensions() -> frozenset[str]:
@ -36,7 +38,25 @@ def _base_path():
return _folder_names_and_paths().base_paths[0]
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None):
def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path:
"""
Absolute posix style paths (aka, paths starting with `/`) are always returned as-is, otherwise this is resolved.
:param path: a path or string to a path
:return: the resolved path
"""
if isinstance(path, PurePosixPath) and path.is_absolute():
return path
return Path(path).resolve()
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False):
"""
Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory.
:param folder_names_and_paths: the object to populate with paths
:param configuration: a configuration whose base_paths and other path settings will be used to set the values on this object
:param create_all_directories: create all the possible directories by calling create_directories() after the object is populated
:return:
"""
from ..cmd.main_pre import args
configuration = configuration or args
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
@ -69,7 +89,20 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
if "HF_HUB_CACHE" in os.environ:
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
folder_names_and_paths.add(hf_cache_paths)
create_directories(folder_names_and_paths)
if create_all_directories:
create_directories(folder_names_and_paths)
if configuration.output_directory is not None:
folder_names_and_paths.application_paths.output_directory = Path(configuration.output_directory)
if configuration.input_directory is not None:
folder_names_and_paths.application_paths.input_directory = Path(configuration.input_directory)
if configuration.temp_directory is not None:
folder_names_and_paths.application_paths.temp_directory = Path(configuration.temp_directory)
if configuration.user_directory is not None:
folder_names_and_paths.application_paths.user_directory = Path(configuration.user_directory)
@_module_properties.getter
@ -84,22 +117,22 @@ def _models_dir():
@_module_properties.getter
def _user_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve())
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.user_directory))
@_module_properties.getter
def _temp_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve())
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.temp_directory))
@_module_properties.getter
def _input_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve())
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.input_directory))
@_module_properties.getter
def _output_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve())
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.output_directory))
@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.")
@ -121,19 +154,19 @@ def set_input_directory(input_dir: str | Path):
def get_output_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve())
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.output_directory))
def get_temp_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve())
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.temp_directory))
def get_input_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve())
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.input_directory))
def get_user_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve())
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.user_directory))
def set_user_directory(user_dir: str | Path) -> None:
@ -223,9 +256,47 @@ def get_folder_paths(folder_name) -> List[str]:
return [path for path in _folder_names_and_paths()[folder_name].paths]
@_deprecate_method(version="0.2.3", message="Not supported")
@_deprecate_method(version="1.0.0", message="Use os.scandir instead.")
def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]:
raise NotImplemented("Unsupported method")
if not os.path.isdir(directory):
return [], {}
if excluded_dir_names is None:
excluded_dir_names = []
result = []
dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
logger.warning(f"Warning: Unable to access {directory}. Skipping this path.")
logger.debug("recursive file list on directory {}".format(directory))
dirpath: str
subdirs: list[str]
filenames: list[str]
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logger.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logger.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
logger.debug("found {} files".format(len(result)))
return result, dirs
def filter_files_extensions(files, extensions):
@ -343,12 +414,11 @@ def _cache_helper():
init_default_paths(_folder_names_and_paths())
__all__ = [
# Properties (stripped leading underscore)
"supported_pt_extensions", # from _supported_pt_extensions
"extension_mimetypes_cache", # from _extension_mimetypes_cache
"base_path", # from _base_path
"folder_names_and_paths", # from _folder_names_and_paths
"models_dir", # from _models_dir
"supported_pt_extensions",
"extension_mimetypes_cache",
"base_path",
"folder_names_and_paths",
"models_dir",
"user_directory",
"output_directory",
"temp_directory",

View File

@ -6,9 +6,9 @@ import os
import typing
import weakref
from abc import ABC, abstractmethod
from pathlib import Path, PurePosixPath
from typing import Any, NamedTuple, Optional, Iterable
from pathlib import Path
from .platform_path import construct_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"])
@ -288,7 +288,15 @@ class FolderNames:
:param folder_paths_dict: A dictionary
:return: A FolderNames object
"""
raise NotImplementedError()
fn = FolderNames()
for folder_name, (paths, extensions) in folder_paths_dict.items():
fn.add(
ModelPaths(folder_names=[folder_name],
supported_extensions=set(extensions),
additional_relative_directory_paths=set(path for path in paths if not Path(path).is_absolute()),
additional_absolute_directory_paths=set(path for path in paths if Path(path).is_absolute()), folder_names_are_relative_directory_paths_too=False
))
return fn
def __getitem__(self, folder_name) -> FolderPathsTuple:
if not isinstance(folder_name, str) or folder_name is None:

View File

@ -9,6 +9,7 @@ from aiohttp import ClientResponse
from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import FolderNames, FolderPathsTuple, supported_pt_extensions
from comfy.execution_context import context_folder_names_and_paths
from comfy.model_filemanager import download_model, track_download_progress, \
create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
@ -74,9 +75,9 @@ async def test_download_model_success(temp_dir):
fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)})
with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
with context_folder_names_and_paths(fake_paths), \
patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('comfy.model_filemanager.check_file_exists', return_value=None), \
patch('comfy.cmd.folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model(
@ -130,9 +131,9 @@ async def test_download_model_url_request_failure(temp_dir):
fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)})
# Mock the create_model_path function
with patch('comfy.model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('comfy.model_filemanager.check_file_exists', return_value=None), \
patch('comfy.cmd.folder_paths.folder_names_and_paths', fake_paths):
with context_folder_names_and_paths(fake_paths), \
patch('comfy.model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('comfy.model_filemanager.check_file_exists', return_value=None):
# Call the function
result = await download_model(
mock_get,