Fixes for tests and completing merge

- the huggingface cache is now used directly on platforms that support symlinking when the requested files already exist in the cache
- absolute imports were changed to relative in the correct places
- StringEnumRequestParameter has a special case in validation
- fix model_management whitespace issue
- fix comfy.ops references

parent a44a039661
commit d9ba795385

@@ -648,7 +648,8 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
             received_type = r[val[1]]
             received_types[x] = received_type
-            if 'input_types' not in validate_function_inputs and received_type != type_input:
+            any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
+            if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
                 details = f"{x}, {received_type} != {type_input}"
                 error = {
                     "type": "return_type_mismatch",
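
For context, the new any_enum flag treats an empty-list RETURN_TYPES entry as a wildcard that may feed any enum (combo) input, so a mismatch against a concrete choice list no longer fails validation. A small illustration with made-up values, not code from the repository:

# illustration only; received_type and type_input are made-up values
received_type = []  # upstream node declared an empty list, i.e. "any enum"
type_input = ["euler", "euler_ancestral"]  # downstream combo input's choices
any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
assert any_enum  # the return_type_mismatch error is skipped in this case
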
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from .... import ops
+from ... import ops
 from ..modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
 from ..modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.attention import optimized_attention
-import comfy.ops
+from ..modules.attention import optimized_attention
+from ... import ops

 class AttentionPool(nn.Module):
     def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
         x = x[:,:self.positional_embedding.shape[0] - 1]
         x = x.permute(1, 0, 2)  # NLC -> LNC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
-        x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC
+        x = x + ops.cast_to_input(self.positional_embedding[:, None, :], x)  # (L+1)NC

         q = self.q_proj(x[:1])
         k = self.k_proj(x)
@@ -4,15 +4,17 @@ import collections
 import logging
 import operator
 import os
+import shutil
 from functools import reduce
 from itertools import chain
 from os.path import join
 from pathlib import Path
-from typing import List, Any, Optional, Sequence, Final, Set, MutableSequence
+from typing import List, Optional, Sequence, Final, Set, MutableSequence

 import tqdm
 from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
-from huggingface_hub.utils import GatedRepoError
+from huggingface_hub.file_download import are_symlinks_supported
+from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
 from requests import Session
 from safetensors import safe_open
 from safetensors.torch import save_file
@@ -82,16 +84,31 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
            file_size = os.stat(path, follow_symlinks=True).st_size if os.path.isfile(path) else None
        except:
            file_size = None
-        if os.path.isfile(path) and known_file.size is None or file_size == known_file.size:
+        if os.path.isfile(path) and file_size == known_file.size:
            return path

-        path = hf_hub_download(repo_id=known_file.repo_id,
-                               filename=known_file.filename,
-                               # todo: in the latest huggingface implementation, this causes files to be downloaded as though the destination is the cache dir, rather than a local directory linking to a cache dir
-                               local_dir=hf_destination_dir,
-                               repo_type=known_file.repo_type,
-                               revision=known_file.revision,
-                               )
+        cache_hit = False
+        try:
+            if not are_symlinks_supported():
+                raise PermissionError("no symlink support")
+            # always retrieve this from the cache if it already exists there
+            path = hf_hub_download(repo_id=known_file.repo_id,
+                                   filename=known_file.filename,
+                                   repo_type=known_file.repo_type,
+                                   revision=known_file.revision,
+                                   local_files_only=True,
+                                   )
+            logging.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
+            if linked_filename is None:
+                linked_filename = known_file.filename
+            cache_hit = True
+        except (LocalEntryNotFoundError, PermissionError):
+            path = hf_hub_download(repo_id=known_file.repo_id,
+                                   filename=known_file.filename,
+                                   local_dir=hf_destination_dir,
+                                   repo_type=known_file.repo_type,
+                                   revision=known_file.revision,
+                                   )

        if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
            tensors = {}
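
This hunk carries the commit's main cache change: probe the Hugging Face cache with local_files_only=True, which never touches the network, and fall back to a real download only when the probe raises LocalEntryNotFoundError or symlinks are unsupported. A minimal standalone sketch of the same pattern; fetch_cache_first and its arguments are illustrative, not part of this codebase:

# minimal sketch, assuming huggingface_hub is installed; names are illustrative
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import are_symlinks_supported
from huggingface_hub.utils import LocalEntryNotFoundError


def fetch_cache_first(repo_id: str, filename: str, local_dir: str) -> str:
    try:
        if not are_symlinks_supported():
            # without symlinks a cache entry cannot be linked cheaply, so skip the probe
            raise PermissionError("no symlink support")
        # local_files_only=True resolves from the cache and raises
        # LocalEntryNotFoundError instead of downloading
        return hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
    except (LocalEntryNotFoundError, PermissionError):
        # cache miss or no symlink support: download into a local dir instead
        return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
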
@@ -103,17 +120,30 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                        del x
                        pb.update()

-            save_file(tensors, path)
+            # always save converted files to the destination so that the huggingface cache is not corrupted
+            save_file(tensors, os.path.join(hf_destination_dir, known_file.filename))

            for _, v in tensors.items():
                del v
            logging.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")

-        try:
-            if linked_filename is not None:
-                os.symlink(os.path.join(hf_destination_dir, known_file.filename), os.path.join(this_model_directory, linked_filename))
-        except Exception as exc_info:
-            logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
+        link_exc_info = None
+        if linked_filename is not None:
+            destination_link = os.path.join(this_model_directory, linked_filename)
+            try:
+                os.makedirs(this_model_directory, exist_ok=True)
+                os.symlink(path, destination_link)
+            except WindowsError:
+                try:
+                    os.link(path, destination_link)
+                except Exception as exc_info:
+                    link_exc_info = exc_info
+                    if cache_hit:
+                        shutil.copyfile(path, destination_link)
+            except Exception as exc_info:
+                link_exc_info = exc_info
+        if link_exc_info is not None:
+            logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
    else:
        url: Optional[str] = None
        save_filename = known_file.save_with_filename or known_file.filename
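
The rewritten linking step degrades gracefully: symlink first, then hard link, then a plain copy when the source is a stable cache entry. A condensed sketch of that fallback order; link_into is a hypothetical helper, and it catches OSError where the hunk above catches WindowsError (an alias of OSError on Windows):

import os
import shutil


def link_into(src: str, dst_dir: str, dst_name: str, can_copy: bool) -> str:
    # hypothetical helper mirroring the symlink -> hard link -> copy fallback
    dst = os.path.join(dst_dir, dst_name)
    os.makedirs(dst_dir, exist_ok=True)
    try:
        os.symlink(src, dst)  # preferred: no duplicate bytes, the cache stays canonical
    except OSError:
        try:
            os.link(src, dst)  # hard links require the same filesystem
        except OSError:
            if can_copy:
                shutil.copyfile(src, dst)  # last resort: duplicate the file
            else:
                raise
    return dst
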
@@ -463,7 +493,7 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
        cache_item.repo_id for cache_item in \
        reduce(operator.or_,
               map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
-        if cache_item.repo_type == "model"
+        if cache_item.repo_type == "model" or cache_item.repo_type == "space"
    )

    # also check local-dir style directories
@@ -489,9 +519,9 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
    return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)


-def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
-    cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
-    local_dirs = folder_paths.get_folder_paths("huggingface")
+def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
+    cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
+    local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)

    local_dirs_cache_hit = len(local_dirs_snapshots) > 0
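
The widened signature lets callers, tests in particular, inject the search directories instead of relying on the folder_paths globals; passing nothing keeps the old behavior. A hypothetical call with illustrative paths:

# illustrative paths; any repo_id known to the downloader works the same way
snapshot_path = get_or_download_huggingface_repo(
    "stabilityai/stable-diffusion-xl-base-1.0",
    cache_dirs=["/tmp/test_hf_cache"],
    local_dirs=["/tmp/test_hf_local"],
)
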
@@ -453,28 +453,28 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
    else:
        minimum_memory_required = max(inference_memory, minimum_memory_required)

-        models = set(models)
-        models_to_load = []
-        models_already_loaded = []
-        for x in models:
-            loaded_model = LoadedModel(x)
-            loaded = None
+    models = set(models)
+    models_to_load = []
+    models_already_loaded = []
+    for x in models:
+        loaded_model = LoadedModel(x)
+        loaded = None

-            try:
-                loaded_model_index = current_loaded_models.index(loaded_model)
-            except ValueError:
-                loaded_model_index = None
+        try:
+            loaded_model_index = current_loaded_models.index(loaded_model)
+        except ValueError:
+            loaded_model_index = None

-            if loaded_model_index is not None:
-                loaded = current_loaded_models[loaded_model_index]
-                if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
-                    current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
-                    loaded = None
-                else:
-                    loaded.currently_used = True
-                    models_already_loaded.append(loaded)
-            if loaded is None:
-                models_to_load.append(loaded_model)
+        if loaded_model_index is not None:
+            loaded = current_loaded_models[loaded_model_index]
+            if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
+                current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
+                loaded = None
+            else:
+                loaded.currently_used = True
+                models_already_loaded.append(loaded)
+        if loaded is None:
+            models_to_load.append(loaded_model)

    models_freed: List[LoadedModel] = []
    try:
@@ -513,7 +513,6 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
            current_free_mem = get_free_memory(torch_dev)
            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
            if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
-
                lowvram_model_memory = 0

        if vram_set_state == VRAMState.NO_VRAM:
@@ -89,9 +89,13 @@ def has_gpu() -> bool:
 @pytest.fixture(scope="module", autouse=False)
 def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
     """
-    starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
+    populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
     :return:
     """
+    from huggingface_hub import hf_hub_download
+    hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
+    hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
+
     tmp_path = tmp_path_factory.mktemp("comfy_background_server")
     processes_to_close: List[subprocess.Popen] = []
     from testcontainers.rabbitmq import RabbitMqContainer
@@ -1,3 +1,4 @@
+import logging
 import os
 import shutil

@@ -45,7 +46,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text


-@pytest.mark.asyncio
 @pytest.mark.skip("flakey")
 def test_known_repos(tmp_path_factory):
     prev_hub_cache = os.getenv("HF_HUB_CACHE")
     os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
@@ -66,21 +66,25 @@ def test_known_repos(tmp_path_factory):
    _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
    args.disable_known_models = False
    try:
-        folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
-        folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
+        folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
+        folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})

        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 0, "not downloaded yet"
        assert len(locals_hits) == 0, "not downloaded yet"

        # test downloading the repo and observing a cache hit on second access
-        existing_repos = get_huggingface_repo_list()
-        assert test_repo_id not in existing_repos
+        try:
+            KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
+            logging.error("unexpected, the test_repo_id was already in the KNOWN_HUGGINGFACE_MODEL_REPOS symbol")
+        except KeyError:
+            known_repos = get_huggingface_repo_list()
+            assert test_repo_id not in known_repos

        # best to import this at the time that it is run, not when the test is initialized
        KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
-        existing_repos = get_huggingface_repo_list()
-        assert test_repo_id in existing_repos
+        known_repos = get_huggingface_repo_list()
+        assert test_repo_id in known_repos

        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
@@ -10,6 +10,7 @@ from PIL import Image
 from freezegun import freeze_time

 from comfy.cmd import folder_paths
+from comfy.component_model.executor_types import ValidateInputsTuple
 from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, \
     StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, \
     UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \
@@ -123,7 +124,36 @@ def test_string_enum_request_parameter():
     n = StringEnumRequestParameter()
     v, = n.execute(value="test", name="test")
     assert v == "test"
-    # todo: check that a graph that uses this in a checkpoint is valid
+    prompt = {
+        "1": {
+            "inputs": {
+                "value": "euler",
+                "name": "sampler_name",
+                "title": "KSampler Node Sampler",
+                "description":
+                    "This allows users to select a sampler for generating images with Latent Diffusion Models, including Stable Diffusion, ComfyUI, and SDXL. \n\nChange this only if explicitly requested by the user.\n\nList of sampler choice (this parameter): valid choices for scheduler (value for scheduler parameter).\n\n- euler: normal, karras, exponential, sgm_uniform, simple, ddim_uniform\n- euler_ancestral: normal, karras\n- heun: normal, karras\n- heunpp2: normal, karras\n- dpm_2: normal, karras\n- dpm_2_ancestral: normal, karras\n- lms: normal, karras\n- dpm_fast: normal, exponential\n- dpm_adaptive: normal, exponential\n- dpmpp_2s_ancestral: karras, exponential\n- dpmpp_sde: karras, exponential\n- dpmpp_sde_gpu: karras, exponential\n- dpmpp_2m: karras, sgm_uniform\n- dpmpp_2m_sde: karras, sgm_uniform\n- dpmpp_2m_sde_gpu: karras, sgm_uniform\n- dpmpp_3m_sde: karras, sgm_uniform\n- dpmpp_3m_sde_gpu: karras, sgm_uniform\n- ddpm: normal, simple\n- lcm: normal, exponential\n- ddim: normal, ddim_uniform\n- uni_pc: normal, karras, exponential\n- uni_pc_bh2: normal, karras, exponential",
+                "__required": True,
+            },
+            "class_type": "StringEnumRequestParameter",
+            "_meta": {
+                "title": "StringEnumRequestParameter",
+            },
+        },
+        "2": {
+            "inputs": {
+                "sampler_name": ["1", 0],
+            },
+            "class_type": "KSamplerSelect",
+            "_meta": {
+                "title": "KSamplerSelect",
+            },
+        },
+    }
+    from comfy.cmd.execution import validate_inputs
+    validated: dict[str, ValidateInputsTuple] = {}
+    validated["1"] = validate_inputs(prompt, "1", validated)
+    validated["2"] = validate_inputs(prompt, "2", validated)
+    assert validated["2"].valid


 @pytest.mark.skip("issues")