mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
Fix tests
This commit is contained in:
parent
4a13766d14
commit
b42e59d602
@ -5,7 +5,7 @@ import mimetypes
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Optional, List, Literal
|
||||
|
||||
from ..cli_args_types import Configuration
|
||||
@ -19,6 +19,8 @@ from ..execution_context import current_execution_context
|
||||
|
||||
_module_properties = create_module_properties()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _supported_pt_extensions() -> frozenset[str]:
|
||||
@ -36,7 +38,25 @@ def _base_path():
|
||||
return _folder_names_and_paths().base_paths[0]
|
||||
|
||||
|
||||
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None):
|
||||
def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path:
|
||||
"""
|
||||
Absolute posix style paths (aka, paths starting with `/`) are always returned as-is, otherwise this is resolved.
|
||||
:param path: a path or string to a path
|
||||
:return: the resolved path
|
||||
"""
|
||||
if isinstance(path, PurePosixPath) and path.is_absolute():
|
||||
return path
|
||||
return Path(path).resolve()
|
||||
|
||||
|
||||
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False):
|
||||
"""
|
||||
Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory.
|
||||
:param folder_names_and_paths: the object to populate with paths
|
||||
:param configuration: a configuration whose base_paths and other path settings will be used to set the values on this object
|
||||
:param create_all_directories: create all the possible directories by calling create_directories() after the object is populated
|
||||
:return:
|
||||
"""
|
||||
from ..cmd.main_pre import args
|
||||
configuration = configuration or args
|
||||
base_paths = [Path(configuration.cwd) if configuration.cwd is not None else None] + configuration.base_paths
|
||||
@ -69,7 +89,20 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
|
||||
if "HF_HUB_CACHE" in os.environ:
|
||||
hf_cache_paths.additional_absolute_directory_paths.add(os.environ.get("HF_HUB_CACHE"))
|
||||
folder_names_and_paths.add(hf_cache_paths)
|
||||
create_directories(folder_names_and_paths)
|
||||
if create_all_directories:
|
||||
create_directories(folder_names_and_paths)
|
||||
|
||||
if configuration.output_directory is not None:
|
||||
folder_names_and_paths.application_paths.output_directory = Path(configuration.output_directory)
|
||||
|
||||
if configuration.input_directory is not None:
|
||||
folder_names_and_paths.application_paths.input_directory = Path(configuration.input_directory)
|
||||
|
||||
if configuration.temp_directory is not None:
|
||||
folder_names_and_paths.application_paths.temp_directory = Path(configuration.temp_directory)
|
||||
|
||||
if configuration.user_directory is not None:
|
||||
folder_names_and_paths.application_paths.user_directory = Path(configuration.user_directory)
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
@ -84,22 +117,22 @@ def _models_dir():
|
||||
|
||||
@_module_properties.getter
|
||||
def _user_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.user_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.user_directory))
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _temp_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.temp_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.temp_directory))
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _input_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.input_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.input_directory))
|
||||
|
||||
|
||||
@_module_properties.getter
|
||||
def _output_directory() -> str:
|
||||
return str(Path(current_execution_context().folder_names_and_paths.application_paths.output_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(current_execution_context().folder_names_and_paths.application_paths.output_directory))
|
||||
|
||||
|
||||
@_deprecate_method(version="0.2.3", message="Mapping of previous folder names is already done by other mechanisms.")
|
||||
@ -121,19 +154,19 @@ def set_input_directory(input_dir: str | Path):
|
||||
|
||||
|
||||
def get_output_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.output_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.output_directory))
|
||||
|
||||
|
||||
def get_temp_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.temp_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.temp_directory))
|
||||
|
||||
|
||||
def get_input_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.input_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.input_directory))
|
||||
|
||||
|
||||
def get_user_directory() -> str:
|
||||
return str(Path(_folder_names_and_paths().application_paths.user_directory).resolve())
|
||||
return str(_resolve_path_with_compatibility(_folder_names_and_paths().application_paths.user_directory))
|
||||
|
||||
|
||||
def set_user_directory(user_dir: str | Path) -> None:
|
||||
@ -223,9 +256,47 @@ def get_folder_paths(folder_name) -> List[str]:
|
||||
return [path for path in _folder_names_and_paths()[folder_name].paths]
|
||||
|
||||
|
||||
@_deprecate_method(version="0.2.3", message="Not supported")
|
||||
@_deprecate_method(version="1.0.0", message="Use os.scandir instead.")
|
||||
def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dict[str, float]]:
|
||||
raise NotImplemented("Unsupported method")
|
||||
if not os.path.isdir(directory):
|
||||
return [], {}
|
||||
|
||||
if excluded_dir_names is None:
|
||||
excluded_dir_names = []
|
||||
|
||||
result = []
|
||||
dirs = {}
|
||||
|
||||
# Attempt to add the initial directory to dirs with error handling
|
||||
try:
|
||||
dirs[directory] = os.path.getmtime(directory)
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Warning: Unable to access {directory}. Skipping this path.")
|
||||
|
||||
logger.debug("recursive file list on directory {}".format(directory))
|
||||
dirpath: str
|
||||
subdirs: list[str]
|
||||
filenames: list[str]
|
||||
|
||||
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
|
||||
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
|
||||
for file_name in filenames:
|
||||
try:
|
||||
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
|
||||
result.append(relative_path)
|
||||
except:
|
||||
logger.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
|
||||
continue
|
||||
|
||||
for d in subdirs:
|
||||
path: str = os.path.join(dirpath, d)
|
||||
try:
|
||||
dirs[path] = os.path.getmtime(path)
|
||||
except FileNotFoundError:
|
||||
logger.warning(f"Warning: Unable to access {path}. Skipping this path.")
|
||||
continue
|
||||
logger.debug("found {} files".format(len(result)))
|
||||
return result, dirs
|
||||
|
||||
|
||||
def filter_files_extensions(files, extensions):
|
||||
@ -343,12 +414,11 @@ def _cache_helper():
|
||||
init_default_paths(_folder_names_and_paths())
|
||||
|
||||
__all__ = [
|
||||
# Properties (stripped leading underscore)
|
||||
"supported_pt_extensions", # from _supported_pt_extensions
|
||||
"extension_mimetypes_cache", # from _extension_mimetypes_cache
|
||||
"base_path", # from _base_path
|
||||
"folder_names_and_paths", # from _folder_names_and_paths
|
||||
"models_dir", # from _models_dir
|
||||
"supported_pt_extensions",
|
||||
"extension_mimetypes_cache",
|
||||
"base_path",
|
||||
"folder_names_and_paths",
|
||||
"models_dir",
|
||||
"user_directory",
|
||||
"output_directory",
|
||||
"temp_directory",
|
||||
|
||||
@ -6,9 +6,9 @@ import os
|
||||
import typing
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, NamedTuple, Optional, Iterable
|
||||
|
||||
from pathlib import Path
|
||||
from .platform_path import construct_path
|
||||
|
||||
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json"])
|
||||
@ -288,7 +288,15 @@ class FolderNames:
|
||||
:param folder_paths_dict: A dictionary
|
||||
:return: A FolderNames object
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
fn = FolderNames()
|
||||
for folder_name, (paths, extensions) in folder_paths_dict.items():
|
||||
fn.add(
|
||||
ModelPaths(folder_names=[folder_name],
|
||||
supported_extensions=set(extensions),
|
||||
additional_relative_directory_paths=set(path for path in paths if not Path(path).is_absolute()),
|
||||
additional_absolute_directory_paths=set(path for path in paths if Path(path).is_absolute()), folder_names_are_relative_directory_paths_too=False
|
||||
))
|
||||
return fn
|
||||
|
||||
def __getitem__(self, folder_name) -> FolderPathsTuple:
|
||||
if not isinstance(folder_name, str) or folder_name is None:
|
||||
|
||||
@ -9,6 +9,7 @@ from aiohttp import ClientResponse
|
||||
|
||||
from comfy.cmd import folder_paths
|
||||
from comfy.component_model.folder_path_types import FolderNames, FolderPathsTuple, supported_pt_extensions
|
||||
from comfy.execution_context import context_folder_names_and_paths
|
||||
from comfy.model_filemanager import download_model, track_download_progress, \
|
||||
create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename
|
||||
|
||||
@ -74,9 +75,9 @@ async def test_download_model_success(temp_dir):
|
||||
|
||||
fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)})
|
||||
|
||||
with patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||
with context_folder_names_and_paths(fake_paths), \
|
||||
patch('comfy.model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'model.sft')), \
|
||||
patch('comfy.model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('comfy.cmd.folder_paths.folder_names_and_paths', fake_paths), \
|
||||
patch('time.time', side_effect=time_values): # Simulate time passing
|
||||
|
||||
result = await download_model(
|
||||
@ -130,9 +131,9 @@ async def test_download_model_url_request_failure(temp_dir):
|
||||
fake_paths = FolderNames.from_dict({'checkpoints': ([temp_dir], folder_paths.supported_pt_extensions)})
|
||||
|
||||
# Mock the create_model_path function
|
||||
with patch('comfy.model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||
patch('comfy.model_filemanager.check_file_exists', return_value=None), \
|
||||
patch('comfy.cmd.folder_paths.folder_names_and_paths', fake_paths):
|
||||
with context_folder_names_and_paths(fake_paths), \
|
||||
patch('comfy.model_filemanager.create_model_path', return_value='/mock/path/model.safetensors'), \
|
||||
patch('comfy.model_filemanager.check_file_exists', return_value=None):
|
||||
# Call the function
|
||||
result = await download_model(
|
||||
mock_get,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user