Fixes for tests and completing merge

- huggingface cache is now better used on platforms that support
   symlinking and the files you are requesting already exist in the
   cache
 - absolute imports were changed to relative in the correct places
 - StringEnumRequestParameter has a special case in validation
 - fix model_management whitespace issue
 - fix comfy.ops references
This commit is contained in:
doctorpangloss 2024-08-01 18:28:51 -07:00
parent a44a039661
commit d9ba795385
9 changed files with 123 additions and 55 deletions

View File

@ -648,7 +648,8 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]]
received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input:
any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
details = f"{x}, {received_type} != {type_input}"
error = {
"type": "return_type_mismatch",

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from .... import ops
from ... import ops
from ..modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from ..modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint

View File

@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
from ..modules.attention import optimized_attention
from ... import ops
class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
x = x[:,:self.positional_embedding.shape[0] - 1]
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
x = x + ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
q = self.q_proj(x[:1])
k = self.k_proj(x)

View File

@ -4,15 +4,17 @@ import collections
import logging
import operator
import os
import shutil
from functools import reduce
from itertools import chain
from os.path import join
from pathlib import Path
from typing import List, Any, Optional, Sequence, Final, Set, MutableSequence
from typing import List, Optional, Sequence, Final, Set, MutableSequence
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
from huggingface_hub.utils import GatedRepoError
from huggingface_hub.file_download import are_symlinks_supported
from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
@ -82,16 +84,31 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
file_size = os.stat(path, follow_symlinks=True).st_size if os.path.isfile(path) else None
except:
file_size = None
if os.path.isfile(path) and known_file.size is None or file_size == known_file.size:
if os.path.isfile(path) and file_size == known_file.size:
return path
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
# todo: in the latest huggingface implementation, this causes files to be downloaded as though the destination is the cache dir, rather than a local directory linking to a cache dir
local_dir=hf_destination_dir,
repo_type=known_file.repo_type,
revision=known_file.revision,
)
cache_hit = False
try:
if not are_symlinks_supported():
raise PermissionError("no symlink support")
# always retrieve this from the cache if it already exists there
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_files_only=True,
)
logging.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
if linked_filename is None:
linked_filename = known_file.filename
cache_hit = True
except (LocalEntryNotFoundError, PermissionError):
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
local_dir=hf_destination_dir,
repo_type=known_file.repo_type,
revision=known_file.revision,
)
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
tensors = {}
@ -103,17 +120,30 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
del x
pb.update()
save_file(tensors, path)
# always save converted files to the destination so that the huggingface cache is not corrupted
save_file(tensors, os.path.join(hf_destination_dir, known_file.filename))
for _, v in tensors.items():
del v
logging.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
try:
if linked_filename is not None:
os.symlink(os.path.join(hf_destination_dir, known_file.filename), os.path.join(this_model_directory, linked_filename))
except Exception as exc_info:
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
link_exc_info = None
if linked_filename is not None:
destination_link = os.path.join(this_model_directory, linked_filename)
try:
os.makedirs(this_model_directory, exist_ok=True)
os.symlink(path, destination_link)
except WindowsError:
try:
os.link(path, destination_link)
except Exception as exc_info:
link_exc_info = exc_info
if cache_hit:
shutil.copyfile(path, destination_link)
except Exception as exc_info:
link_exc_info = exc_info
if link_exc_info is not None:
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
else:
url: Optional[str] = None
save_filename = known_file.save_with_filename or known_file.filename
@ -463,7 +493,7 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
cache_item.repo_id for cache_item in \
reduce(operator.or_,
map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
if cache_item.repo_type == "model"
if cache_item.repo_type == "model" or cache_item.repo_type == "space"
)
# also check local-dir style directories
@ -489,9 +519,9 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
local_dirs = folder_paths.get_folder_paths("huggingface")
def get_or_download_huggingface_repo(repo_id: str, cache_dirs: Optional[list] = None, local_dirs: Optional[list] = None) -> Optional[str]:
cache_dirs = cache_dirs or folder_paths.get_folder_paths("huggingface_cache")
local_dirs = local_dirs or folder_paths.get_folder_paths("huggingface")
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
local_dirs_cache_hit = len(local_dirs_snapshots) > 0

View File

@ -453,28 +453,28 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)
models = set(models)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
models = set(models)
models_to_load = []
models_already_loaded = []
for x in models:
loaded_model = LoadedModel(x)
loaded = None
try:
loaded_model_index = current_loaded_models.index(loaded_model)
except ValueError:
loaded_model_index = None
try:
loaded_model_index = current_loaded_models.index(loaded_model)
except ValueError:
loaded_model_index = None
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
models_to_load.append(loaded_model)
if loaded_model_index is not None:
loaded = current_loaded_models[loaded_model_index]
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
loaded = None
else:
loaded.currently_used = True
models_already_loaded.append(loaded)
if loaded is None:
models_to_load.append(loaded_model)
models_freed: List[LoadedModel] = []
try:
@ -513,7 +513,6 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:

View File

View File

@ -89,9 +89,13 @@ def has_gpu() -> bool:
@pytest.fixture(scope="module", autouse=False)
def frontend_backend_worker_with_rabbitmq(tmp_path_factory) -> str:
"""
starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
populates the cache with the sdxl checkpoints, starts a frontend and backend worker against a started rabbitmq, and yields the address of the frontend
:return:
"""
from huggingface_hub import hf_hub_download
hf_hub_download("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors")
hf_hub_download("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors")
tmp_path = tmp_path_factory.mktemp("comfy_background_server")
processes_to_close: List[subprocess.Popen] = []
from testcontainers.rabbitmq import RabbitMqContainer

View File

@ -1,3 +1,4 @@
import logging
import os
import shutil
@ -45,7 +46,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@pytest.mark.asyncio
@pytest.mark.skip("flakey")
def test_known_repos(tmp_path_factory):
prev_hub_cache = os.getenv("HF_HUB_CACHE")
os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
@ -66,21 +66,25 @@ def test_known_repos(tmp_path_factory):
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
args.disable_known_models = False
try:
folder_paths.folder_names_and_paths["huggingface"] += FolderPathsTuple("huggingface", [test_local_dir], {""})
folder_paths.folder_names_and_paths["huggingface_cache"] += FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == 0, "not downloaded yet"
assert len(locals_hits) == 0, "not downloaded yet"
# test downloading the repo and observing a cache hit on second access
existing_repos = get_huggingface_repo_list()
assert test_repo_id not in existing_repos
try:
KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
logging.error("unexpected, the test_repo_id was already in the KNOWN_HUGGINGFACE_MODEL_REPOS symbol")
except KeyError:
known_repos = get_huggingface_repo_list()
assert test_repo_id not in known_repos
# best to import this at the time that it is run, not when the test is initialized
KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
existing_repos = get_huggingface_repo_list()
assert test_repo_id in existing_repos
known_repos = get_huggingface_repo_list()
assert test_repo_id in known_repos
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"

View File

@ -10,6 +10,7 @@ from PIL import Image
from freezegun import freeze_time
from comfy.cmd import folder_paths
from comfy.component_model.executor_types import ValidateInputsTuple
from comfy_extras.nodes.nodes_open_api import SaveImagesResponse, IntRequestParameter, FloatRequestParameter, \
StringRequestParameter, HashImage, StringPosixPathJoin, LegacyOutputURIs, DevNullUris, StringJoin, StringToUri, \
UriFormat, ImageExifMerge, ImageExifCreationDateAndBatchNumber, ImageExif, ImageExifUncommon, \
@ -123,7 +124,36 @@ def test_string_enum_request_parameter():
n = StringEnumRequestParameter()
v, = n.execute(value="test", name="test")
assert v == "test"
# todo: check that a graph that uses this in a checkpoint is valid
prompt = {
"1": {
"inputs": {
"value": "euler",
"name": "sampler_name",
"title": "KSampler Node Sampler",
"description":
"This allows users to select a sampler for generating images with Latent Diffusion Models, including Stable Diffusion, ComfyUI, and SDXL. \n\nChange this only if explicitly requested by the user.\n\nList of sampler choice (this parameter): valid choices for scheduler (value for scheduler parameter).\n\n- euler: normal, karras, exponential, sgm_uniform, simple, ddim_uniform\n- euler_ancestral: normal, karras\n- heun: normal, karras\n- heunpp2: normal, karras\n- dpm_2: normal, karras\n- dpm_2_ancestral: normal, karras\n- lms: normal, karras\n- dpm_fast: normal, exponential\n- dpm_adaptive: normal, exponential\n- dpmpp_2s_ancestral: karras, exponential\n- dpmpp_sde: karras, exponential\n- dpmpp_sde_gpu: karras, exponential\n- dpmpp_2m: karras, sgm_uniform\n- dpmpp_2m_sde: karras, sgm_uniform\n- dpmpp_2m_sde_gpu: karras, sgm_uniform\n- dpmpp_3m_sde: karras, sgm_uniform\n- dpmpp_3m_sde_gpu: karras, sgm_uniform\n- ddpm: normal, simple\n- lcm: normal, exponential\n- ddim: normal, ddim_uniform\n- uni_pc: normal, karras, exponential\n- uni_pc_bh2: normal, karras, exponential",
"__required": True,
},
"class_type": "StringEnumRequestParameter",
"_meta": {
"title": "StringEnumRequestParameter",
},
},
"2": {
"inputs": {
"sampler_name": ["1", 0],
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect",
},
},
}
from comfy.cmd.execution import validate_inputs
validated: dict[str, ValidateInputsTuple] = {}
validated["1"] = validate_inputs(prompt, "1", validated)
validated["2"] = validate_inputs(prompt, "2", validated)
assert validated["2"].valid
@pytest.mark.skip("issues")