Fix tests

This commit is contained in:
doctorpangloss 2024-10-29 12:29:02 -07:00
parent 4a13766d14
commit b42e59d602
3 changed files with 105 additions and 26 deletions

View File

@ -5,7 +5,7 @@ import mimetypes
import os import os
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path, PurePosixPath
from typing import Optional, List, Literal from typing import Optional, List, Literal
from ..cli_args_types import Configuration from ..cli_args_types import Configuration
@ -19,6 +19,8 @@ from ..execution_context import current_execution_context
_module_properties = create_module_properties() _module_properties = create_module_properties()
logger = logging.getLogger(__name__)
@_module_properties.getter @_module_properties.getter
def _supported_pt_extensions() -> frozenset[str]: def _supported_pt_extensions() -> frozenset[str]:
@ -36,7 +38,25 @@ def _base_path():
return _folder_names_and_paths().base_paths[0] return _folder_names_and_paths().base_paths[0]
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None): def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path:
"""
Absolute posix style paths (aka, paths starting with `/`) are always returned as-is, otherwise this is resolved.
:param path: a path or string to a path
:return: the resolved path
"""
if isinstance(path, PurePosixPath) and path.is_absolute():
return path
return Path(path).resolve()
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False):
"""
Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory.
:param folder_names_and_paths: the object to populate with paths
:param configuration: a configuration whose base_paths and other path settings will be used to set the values on this object
:param create_all_directories: create all the possible directories by calling create_directories() after the object is populated
:return:
"""
from ..cmd.main_pre import args from ..cmd.main_pre import args
configuration = configuration or args configuration = configuration or args
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
@ -69,7 +89,20 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
if "HF_HUB_CACHE" in os.environ: if "HF_HUB_CACHE" in os.environ:
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE")) hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
folder_names_and_paths.add(hf_cache_paths) folder_names_and_paths.add(hf_cache_paths)
create_directories(folder_names_and_paths) if create_all_directories:
create_directories(folder_names_and_paths)
if configuration.output_directory is not None:
folder_names_and_paths.application_paths.output_directory = Path(configuration.output_directory)
if configuration.input_directory is not None:
folder_names_and_paths.application_paths.input_directory = Path(configuration.input_directory)
if configuration.temp_directory is not None:
folder_names_and_paths.application_paths.temp_directory = Path(configuration.temp_directory)
if configuration.user_directory is not None:
folder_names_and_paths.application_paths.user_directory = Path(configuration.user_directory)
@_module_properties.getter @_module_properties.getter
@ -84,22 +117,22 @@ def _models_dir():
@_module_properties.getter @_module_properties.getter
def _user_directory() -> str: def _user_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve()) return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.user_directory))
@_module_properties.getter @_module_properties.getter
def _temp_directory() -> str: def _temp_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve()) return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.temp_directory))
@_module_properties.getter @_module_properties.getter
def _input_directory() -> str: def _input_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve()) return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.input_directory))
@_module_properties.getter @_module_properties.getter
def _output_directory() -> str: def _output_directory() -> str:
return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve()) return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.output_directory))
@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.") @_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.")
@ -121,19 +154,19 @@ def set_input_directory(input_dir: str | Path):
def get_output_directory() -> str: def get_output_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve()) return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.output_directory))
def get_temp_directory() -> str: def get_temp_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve()) return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.temp_directory))
def get_input_directory() -> str: def get_input_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve()) return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.input_directory))
def get_user_directory() -> str: def get_user_directory() -> str:
return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve()) return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.user_directory))
def set_user_directory(user_dir: str | Path) -> None: def set_user_directory(user_dir: str | Path) -> None:
@ -223,9 +256,47 @@ def get_folder_paths(folder_name) -> List[str]:
return [path for path in _folder_names_and_paths()[folder_name].paths] return [path for path in _folder_names_and_paths()[folder_name].paths]
@_deprecate_method(version="0.2.3", message="Not supported") @_deprecate_method(version="1.0.0", message="Use os.scandir instead.")
def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]: def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]:
raise NotImplemented("Unsupported method") if not os.path.isdir(directory):
return [], {}
if excluded_dir_names is None:
excluded_dir_names = []
result = []
dirs = {}
# Attempt to add the initial directory to dirs with error handling
try:
dirs[directory] = os.path.getmtime(directory)
except FileNotFoundError:
logger.warning(f"Warning: Unable to access {directory}. Skipping this path.")
logger.debug("recursive file list on directory {}".format(directory))
dirpath: str
subdirs: list[str]
filenames: list[str]
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logger.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logger.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
logger.debug("found {} files".format(len(result)))
return result, dirs
def filter_files_extensions(files, extensions): def filter_files_extensions(files, extensions):
@ -343,12 +414,11 @@ def _cache_helper():
init_default_paths(_folder_names_and_paths()) init_default_paths(_folder_names_and_paths())
__all__ = [ __all__ = [
# Properties (stripped leading underscore) "supported_pt_extensions",
"supported_pt_extensions", # from _supported_pt_extensions "extension_mimetypes_cache",
"extension_mimetypes_cache", # from _extension_mimetypes_cache "base_path",
"base_path", # from _base_path "folder_names_and_paths",
"folder_names_and_paths", # from _folder_names_and_paths "models_dir",
"models_dir", # from _models_dir
"user_directory", "user_directory",
"output_directory", "output_directory",
"temp_directory", "temp_directory",

View File

@ -6,9 +6,9 @@ import os
import typing import typing
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path, PurePosixPath
from typing import Any, NamedTuple, Optional, Iterable from typing import Any, NamedTuple, Optional, Iterable
from pathlib import Path
from .platform_path import construct_path from .platform_path import construct_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"]) supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"])
@ -288,7 +288,15 @@ class FolderNames:
:param folder_paths_dict: A dictionary :param folder_paths_dict: A dictionary
:return: A FolderNames object :return: A FolderNames object
""" """
raise NotImplementedError() fn = FolderNames()
for folder_name, (paths, extensions) in folder_paths_dict.items():
fn.add(
ModelPaths(folder_names=[folder_name],
supported_extensions=set(extensions),
additional_relative_directory_paths=set(path for path in paths if not Path(path).is_absolute()),
additional_absolute_directory_paths=set(path for path in paths if Path(path).is_absolute()), folder_names_are_relative_directory_paths_too=False
))
return fn
def __getitem__(self, folder_name) -> FolderPathsTuple: def __getitem__(self, folder_name) -> FolderPathsTuple:
if not isinstance(folder_name, str) or folder_name is None: if not isinstance(folder_name, str) or folder_name is None:

View File

@ -9,6 +9,7 @@ from aiohttp import ClientResponse
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
from comfy.component_model.folder_path_types import FolderNames, FolderPathsTuple, supported_pt_extensions from comfy.component_model.folder_path_types import FolderNames, FolderPathsTuple, supported_pt_extensions
from comfy.execution_context import context_folder_names_and_paths
from comfy.model_filemanager import download_model, track_download_progress, \ from comfy.model_filemanager import download_model, track_download_progress, \
create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
@ -74,9 +75,9 @@ async def test_download_model_success(temp_dir):
fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}) fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)})
with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \ with context_folder_names_and_paths(fake_paths), \
patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
patch('comfy.model_filemanager.check_file_exists', return_value=None), \ patch('comfy.model_filemanager.check_file_exists', return_value=None), \
patch('comfy.cmd.folder_paths.folder_names_and_paths', fake_paths), \
patch('time.time', side_effect=time_values): # Simulate time passing patch('time.time', side_effect=time_values): # Simulate time passing
result = await download_model( result = await download_model(
@ -130,9 +131,9 @@ async def test_download_model_url_request_failure(temp_dir):
fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)}) fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)})
# Mock the create_model_path function # Mock the create_model_path function
with patch('comfy.model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \ with context_folder_names_and_paths(fake_paths), \
patch('comfy.model_filemanager.check_file_exists', return_value=None), \ patch('comfy.model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
patch('comfy.cmd.folder_paths.folder_names_and_paths', fake_paths): patch('comfy.model_filemanager.check_file_exists', return_value=None):
# Call the function # Call the function
result = await download_model( result = await download_model(
mock_get, mock_get,