From d9ba79538523d697a447e13a0e0cb89fc7cf41d5 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Thu, 1 Aug 2024 18:28:51 -0700
Subject: [PATCH] Fixes for tests and completing merge

- the Hugging Face cache is now used directly on platforms that support
  symlinking when the requested files already exist in the cache
- absolute imports were changed to relative imports in the correct places
- StringEnumRequestParameter now has a special case in validation
- fix a model_management whitespace (indentation) issue
- fix comfy.ops references
---
 comfy/cmd/execution.py                        |  3 +-
 comfy/ldm/hydit/models.py                     |  2 +-
 comfy/ldm/hydit/poolers.py                    |  6 +-
 comfy/model_downloader.py                     | 70 +++++++++++++------
 comfy/model_management.py                     | 41 ++++++-----
 folder_paths.py                               |  0
 tests/conftest.py                             |  6 +-
 .../downloader/test_huggingface_downloads.py  | 18 +++--
 tests/unit/test_openapi_nodes.py              | 32 ++++++++-
 9 files changed, 123 insertions(+), 55 deletions(-)
 delete mode 100644 folder_paths.py

diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index af6e00d13..58edc9f3e 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -648,7 +648,8 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
             received_type = r[val[1]]
             received_types[x] = received_type
-            if 'input_types' not in validate_function_inputs and received_type != type_input:
+            any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
+            if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
                 details = f"{x}, {received_type} != {type_input}"
                 error = {
                     "type": "return_type_mismatch",
diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 2170f9475..f245b37fb 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .... import ops
+from ... import ops
 from ..modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
 from ..modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
diff --git a/comfy/ldm/hydit/poolers.py b/comfy/ldm/hydit/poolers.py
index f5e5b406f..67987c7a6 100644
--- a/comfy/ldm/hydit/poolers.py
+++ b/comfy/ldm/hydit/poolers.py
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.attention import optimized_attention
-import comfy.ops
+from ..modules.attention import optimized_attention
+from ... import ops
 
 class AttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
         x = x[:,:self.positional_embedding.shape[0] - 1]
         x = x.permute(1, 0, 2) # NLC -> LNC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
-        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
+        x = x + ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
 
         q = self.q_proj(x[:1])
         k = self.k_proj(x)
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 50d89219e..df61cc198 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -4,15 +4,17 @@ import collections
 import logging
 import operator
 import os
+import shutil
 from functools import reduce
 from itertools import chain
 from os.path import join
 from pathlib import Path
-from typing import List, Any, Optional, Sequence, Final, Set, MutableSequence
+from typing import List, Optional, Sequence, Final, Set, MutableSequence
 
 import tqdm
 from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
-from huggingface_hub.utils import GatedRepoError
+from huggingface_hub.file_download import are_symlinks_supported
+from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
 from requests import Session
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -82,16 +84,31 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                 file_size = os.stat(path, follow_symlinks=True).st_size if os.path.isfile(path) else None
             except:
                 file_size = None
-            if os.path.isfile(path) and known_file.size is None or file_size == known_file.size:
+            if os.path.isfile(path) and file_size == known_file.size:
                 return path
 
-            path = hf_hub_download(repo_id=known_file.repo_id,
-                                   filename=known_file.filename,
-                                   # todo: in the latest huggingface implementation, this causes files to be downloaded as though the destination is the cache dir, rather than a local directory linking to a cache dir
-                                   local_dir=hf_destination_dir,
-                                   repo_type=known_file.repo_type,
-                                   revision=known_file.revision,
-                                   )
+            cache_hit = False
+            try:
+                if not are_symlinks_supported():
+                    raise PermissionError("no symlink support")
+                # always retrieve this from the cache if it already exists there
+                path = hf_hub_download(repo_id=known_file.repo_id,
+                                       filename=known_file.filename,
+                                       repo_type=known_file.repo_type,
+                                       revision=known_file.revision,
+                                       local_files_only=True,
+                                       )
+                logging.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
+                if linked_filename is None:
+                    linked_filename = known_file.filename
+                cache_hit = True
+            except (LocalEntryNotFoundError, PermissionError):
+                path = hf_hub_download(repo_id=known_file.repo_id,
+                                       filename=known_file.filename,
+                                       local_dir=hf_destination_dir,
+                                       repo_type=known_file.repo_type,
+                                       revision=known_file.revision,
+                                       )
 
             if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
                 tensors = {}
@@ -103,17 +120,30 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                             del x
                             pb.update()
 
-                save_file(tensors, path)
+                # always save converted files to the destination so that the huggingface cache is not corrupted
+                save_file(tensors, os.path.join(hf_destination_dir, known_file.filename))
 
                 for _, v in tensors.items():
                     del v
                 logging.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
 
-            try:
-                if linked_filename is not None:
-                    os.symlink(os.path.join(hf_destination_dir, known_file.filename), os.path.join(this_model_directory, linked_filename))
-            except Exception as exc_info:
-                logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
+            link_exc_info = None
+            if linked_filename is not None:
+                destination_link = os.path.join(this_model_directory, linked_filename)
+                try:
+                    os.makedirs(this_model_directory, exist_ok=True)
+                    os.symlink(path, destination_link)
+                except WindowsError:
+                    try:
+                        os.link(path, destination_link)
+                    except Exception as exc_info:
+                        link_exc_info = exc_info
+                        if cache_hit:
+                            shutil.copyfile(path, destination_link)
+                except Exception as exc_info:
+                    link_exc_info = exc_info
+            if link_exc_info is not None:
+                logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=link_exc_info)
         else:
             url: Optional[str] = None
             save_filename = known_file.save_with_filename or known_file.filename
@@ -463,7 +493,7 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
         cache_item.repo_id for cache_item in \
         reduce(operator.or_, map(lambda cache_info: cache_info.repos,
                                  [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
-        if cache_item.repo_type == "model"
+        if cache_item.repo_type == "model" or cache_item.repo_type == "space"
     )
 
     # also check local-dir style directories
@@ -489,9 +519,9 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
     return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
 
 
-def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
-    cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
-    local_dirs = folder_paths.get_folder_paths("huggingface")
+def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
+    cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
+    local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
     cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
 
     local_dirs_cache_hit = len(local_dirs_snapshots) > 0
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 08f9b471c..7226b0425 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -453,28 +453,28 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
     else:
         minimum_memory_required = max(inference_memory, minimum_memory_required)
 
-        models = set(models)
-        models_to_load = []
-        models_already_loaded = []
-        for x in models:
-            loaded_model = LoadedModel(x)
-            loaded = None
+    models = set(models)
+    models_to_load = []
+    models_already_loaded = []
+    for x in models:
+        loaded_model = LoadedModel(x)
+        loaded = None
 
-            try:
-                loaded_model_index = current_loaded_models.index(loaded_model)
-            except ValueError:
-                loaded_model_index = None
+        try:
+            loaded_model_index = current_loaded_models.index(loaded_model)
+        except ValueError:
+            loaded_model_index = None
 
-            if loaded_model_index is not None:
-                loaded = current_loaded_models[loaded_model_index]
-                if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
-                    current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
-                    loaded = None
-                else:
-                    loaded.currently_used = True
-                    models_already_loaded.append(loaded)
-            if loaded is None:
-                models_to_load.append(loaded_model)
+        if loaded_model_index is not None:
+            loaded = current_loaded_models[loaded_model_index]
+            if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
+                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
+                loaded = None
+            else:
+                loaded.currently_used = True
+                models_already_loaded.append(loaded)
+        if loaded is None:
+            models_to_load.append(loaded_model)
 
     models_freed: List[LoadedModel] = []
     try:
@@ -513,7 +513,6 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
 
             current_free_mem = get_free_memory(torch_dev)
             lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
             if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
-                lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
diff --git a/folder_paths.py b/folder_paths.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/conftest.py b/tests/conftest.py
index 1c8c973ec..f36264f17 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -89,9 +89,13 @@ def has_gpu() -> bool:
 @pytest.fixture(scope="module", autouse=False)
 def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
     """
-    starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
+    populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
     :return:
     """
+    from huggingface_hub import hf_hub_download
+    hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
+    hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
+
     tmp_path = tmp_path_factory.mktemp("comfy_background_server")
     processes_to_close: List[subprocess.Popen] = []
     from testcontainers.rabbitmq import RabbitMqContainer
diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py
index 4a6225380..fa881ad22 100644
--- a/tests/downloader/test_huggingface_downloads.py
+++ b/tests/downloader/test_huggingface_downloads.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import shutil
 
@@ -45,7 +46,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 
 
 @pytest.mark.asyncio
-@pytest.mark.skip("flakey")
 def test_known_repos(tmp_path_factory):
     prev_hub_cache = os.getenv("HF_HUB_CACHE")
     os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
@@ -66,21 +66,25 @@ def test_known_repos(tmp_path_factory):
     _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
     args.disable_known_models = False
     try:
-        folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
-        folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
+        folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
+        folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
 
         cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 0, "not downloaded yet"
         assert len(locals_hits) == 0, "not downloaded yet"
 
         # test downloading the repo and observing a cache hit on second access
-        existing_repos = get_huggingface_repo_list()
-        assert test_repo_id not in existing_repos
+        try:
+            KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
+            logging.error("unexpected, the test_repo_id was already in the KNOWN_HUGGINGFACE_MODEL_REPOS symbol")
+        except KeyError:
+            known_repos = get_huggingface_repo_list()
+            assert test_repo_id not in known_repos
 
         # best to import this at the time that it is run, not when the test is initialized
         KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
-        existing_repos = get_huggingface_repo_list()
-        assert test_repo_id in existing_repos
+        known_repos = get_huggingface_repo_list()
+        assert test_repo_id in known_repos
 
         cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
         assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py
index 2b92ab316..e8b85d187 100644
--- a/tests/unit/test_openapi_nodes.py
+++ b/tests/unit/test_openapi_nodes.py
@@ -10,6 +10,7 @@ from PIL import Image
 from freezegun import freeze_time
 
 from comfy.cmd import folder_paths
+from comfy.component_model.executor_types import ValidateInputsTuple
 from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, \
     StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, \
     UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \
@@ -123,7 +124,36 @@ def test_string_enum_request_parameter():
     n = StringEnumRequestParameter()
     v, = n.execute(value="test", name="test")
     assert v == "test"
-    # todo: check that a graph that uses this in a checkpoint is valid
+    prompt = {
+        "1": {
+            "inputs": {
+                "value": "euler",
+                "name": "sampler_name",
+                "title": "KSampler Node Sampler",
+                "description":
+                    "This allows users to select a sampler for generating images with Latent Diffusion Models, including Stable Diffusion, ComfyUI, and SDXL. \n\nChange this only if explicitly requested by the user.\n\nList of sampler choice (this parameter): valid choices for scheduler (value for scheduler parameter).\n\n- euler: normal, karras, exponential, sgm_uniform, simple, ddim_uniform\n- euler_ancestral: normal, karras\n- heun: normal, karras\n- heunpp2: normal, karras\n- dpm_2: normal, karras\n- dpm_2_ancestral: normal, karras\n- lms: normal, karras\n- dpm_fast: normal, exponential\n- dpm_adaptive: normal, exponential\n- dpmpp_2s_ancestral: karras, exponential\n- dpmpp_sde: karras, exponential\n- dpmpp_sde_gpu: karras, exponential\n- dpmpp_2m: karras, sgm_uniform\n- dpmpp_2m_sde: karras, sgm_uniform\n- dpmpp_2m_sde_gpu: karras, sgm_uniform\n- dpmpp_3m_sde: karras, sgm_uniform\n- dpmpp_3m_sde_gpu: karras, sgm_uniform\n- ddpm: normal, simple\n- lcm: normal, exponential\n- ddim: normal, ddim_uniform\n- uni_pc: normal, karras, exponential\n- uni_pc_bh2: normal, karras, exponential",
+                "__required": True,
+            },
+            "class_type": "StringEnumRequestParameter",
+            "_meta": {
+                "title": "StringEnumRequestParameter",
+            },
+        },
+        "2": {
+            "inputs": {
+                "sampler_name": ["1", 0],
+            },
+            "class_type": "KSamplerSelect",
+            "_meta": {
+                "title": "KSamplerSelect",
+            },
+        },
+    }
+    from comfy.cmd.execution import validate_inputs
+    validated: dict[str, ValidateInputsTuple] = {}
+    validated["1"] = validate_inputs(prompt, "1", validated)
+    validated["2"] = validate_inputs(prompt, "2", validated)
+    assert validated["2"].valid
 
 
 @pytest.mark.skip("issues")
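Note (reviewer sketch, not part of the patch): the get_or_download change above amounts to "resolve from the Hugging Face cache first, and fall back to a local_dir download on a cache miss or when symlinks are unavailable." A minimal standalone sketch of that pattern follows, using only the huggingface_hub calls the patch itself imports (hf_hub_download, are_symlinks_supported, LocalEntryNotFoundError); the helper name fetch_cache_first and the values in the usage comment are illustrative, not part of this change.

    from huggingface_hub import hf_hub_download
    from huggingface_hub.file_download import are_symlinks_supported
    from huggingface_hub.utils import LocalEntryNotFoundError


    def fetch_cache_first(repo_id: str, filename: str, local_dir: str) -> str:
        """Prefer a cached copy; download into local_dir only when the cache cannot be used."""
        try:
            if not are_symlinks_supported():
                # without symlink support a cache hit cannot be linked cheaply into local_dir,
                # so fall through to the plain local_dir download below
                raise PermissionError("no symlink support")
            # local_files_only=True resolves from the cache and raises LocalEntryNotFoundError on a miss
            return hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
        except (LocalEntryNotFoundError, PermissionError):
            return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)


    # usage (illustrative values, mirroring the checkpoints fetched in tests/conftest.py):
    # fetch_cache_first("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/checkpoints")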