diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py
index cd1a6fdbd..ee178cea5 100644
--- a/comfy/client/embedded_comfy_client.py
+++ b/comfy/client/embedded_comfy_client.py
@@ -23,6 +23,7 @@ from .async_progress_iterable import QueuePromptWithProgress
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
+from ..cli_args import default_configuration
 from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.make_mutable import make_mutable
@@ -176,7 +177,7 @@ class Comfy:
     In order to use this in blocking methods, learn more about asyncio online.
     """

-    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor","ContextVarExecutor"] = None):
+    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
         self._progress_handler = progress_handler or ServerStub()
         self._owns_executor = executor is None or isinstance(executor, str)
         if self._owns_executor:
@@ -188,6 +189,7 @@ class Comfy:
         else:
             assert not isinstance(executor, str)
             self._executor = executor
+        self._default_configuration = default_configuration()
         self._configuration = configuration
         self._is_running = False
         self._task_count_lock = RLock()
@@ -332,5 +334,9 @@ class Comfy:
             with self._task_count_lock:
                 self._task_count -= 1

+    def __str__(self):
+        diff = {k: v for k, v in (self._configuration or {}).items() if v != self._default_configuration.get(k)}
+        return f"<Comfy {diff}>"
+

 EmbeddedComfyClient = Comfy
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 0d0c88ce6..f1757f95b 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -1,9 +1,11 @@
 # Original code can be found on: https://github.com/black-forest-labs/flux

-import torch
 from dataclasses import dataclass
-from einops import rearrange, repeat
+
+import torch
 from torch import Tensor, nn
+from einops import rearrange, repeat

+from ..common_dit import pad_to_patch_size
 from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
 from .layers import (
@@ -16,7 +18,6 @@ from .layers import (
     Modulation,
     RMSNorm
 )
-from .. import common_dit


 @dataclass
@@ -33,7 +34,7 @@ class FluxParams:
     axes_dim: list
     theta: int
     patch_size: int
-    qkv_bias: bool 
+    qkv_bias: bool
     guidance_embed: bool
     txt_ids_dims: list
     global_modulation: bool = False
@@ -52,8 +53,6 @@ class Flux(nn.Module):

     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
-        # todo: should this be here?
-        self.device = device
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
@@ -147,9 +146,6 @@ class Flux(nn.Module):

         if transformer_options is None:
             transformer_options = {}
-        if y is None:
-            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -273,10 +269,12 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
         return img

-    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         bs, c, h, w = x.shape
         patch_size = self.patch_size
-        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = pad_to_patch_size(x, (patch_size, patch_size))

         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         h_len = ((h + (patch_size // 2)) // patch_size)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 982ba43bb..e29655d65 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -794,7 +794,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
     if quant_config:
         model_config.quant_config = quant_config
-        logger.info("Detected mixed precision quantization")
+        logger.debug("Detected mixed precision quantization")

     if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
         model_config.custom_operations = GGMLOps
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 37d0987f2..8082536db 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -14,6 +14,7 @@ from pathlib import Path
 from typing import List, Optional, Final, Set

 import requests
+import requests_cache
 import tqdm
 from huggingface_hub import dump_environment_info, hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
 from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
@@ -137,37 +138,39 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
         path = None
         cache_hit = False

-        try:
-            # always retrieve this from the cache if it already exists there
-            path = hf_hub_download(repo_id=known_file.repo_id,
-                                   filename=known_file.filename,
-                                   repo_type=known_file.repo_type,
-                                   revision=known_file.revision,
-                                   local_files_only=True,
-                                   local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
-                                   )
-            logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
-            cache_hit = True
-        except LocalEntryNotFoundError:
+        hf_hub_download_kwargs = dict(repo_id=known_file.repo_id,
+                                      filename=known_file.filename,
+                                      repo_type=known_file.repo_type,
+                                      revision=known_file.revision,
+                                      local_files_only=True,
+                                      local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
+                                      token=True,
+                                      )
+
+        with requests_cache.disabled():
             try:
-                logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
-                path = hf_hub_download(repo_id=known_file.repo_id,
-                                       filename=known_file.filename,
-                                       repo_type=known_file.repo_type,
-                                       revision=known_file.revision,
-                                       local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
-                                       )
-            except requests.exceptions.HTTPError as exc_info:
-                if exc_info.response.status_code == 401:
-                    raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
-            except IOError as exc_info:
-                logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
-            except Exception as exc_info:
-                logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
-                dump_environment_info()
-                for key, value in os.environ.items():
-                    if key.startswith("HF_XET"):
-                        print(f"{key}={value}", file=sys.stderr)
+                # always retrieve this from the cache if it already exists there
+                path = hf_hub_download(**hf_hub_download_kwargs)
+                logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
+                cache_hit = True
+            except LocalEntryNotFoundError:
+                try:
+                    logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
+                    hf_hub_download_kwargs.pop("local_files_only")
+                    path = hf_hub_download(**hf_hub_download_kwargs)
+                except requests.exceptions.HTTPError as exc_info:
+                    if exc_info.response.status_code == 401:
+                        raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
+                except IOError as exc_info:
+                    logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
+                except Exception as exc_info:
+                    logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}. hf_hub_download kwargs={hf_hub_download_kwargs}", exc_info=exc_info)
+                    dump_environment_info()
+                    for key, value in os.environ.items():
+                        if key.startswith("HF_"):
+                            if key == "HF_TOKEN":
+                                value = "*****"
+                            print(f"{key}={value}", file=sys.stderr)

         if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
             tensors = {}
diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py
index d082ab1f9..0b69cca72 100644
--- a/comfy/model_downloader_types.py
+++ b/comfy/model_downloader_types.py
@@ -1,21 +1,18 @@
 from __future__ import annotations

-import collections
 import dataclasses
 import functools
 from os.path import split
 from pathlib import PurePosixPath
-from typing import Optional, List, Sequence, Union, Iterable, Protocol
+from typing import Optional, List, Sequence, Union, Iterable

 from can_ada import parse, URL # pylint: disable=no-name-in-module
-from typing_extensions import TypedDict, NotRequired, runtime_checkable
+from typing_extensions import TypedDict, NotRequired

 from .component_model.executor_types import ComboOptions
 from .component_model.files import canonicalize_path


-
-
 @dataclasses.dataclass(frozen=True)
 class UrlFile:
     _url: str
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 75a8f79e7..8404c5a3f 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -116,7 +116,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             model_options = {}
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
-        assert layer in self.LAYERS

         if textmodel_json_config is None and "model_name" not in model_options:
             model_options = {**model_options, "model_name": "clip_l"}
diff --git a/comfy/utils.py b/comfy/utils.py
index 853850b34..d1c546691 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1539,6 +1539,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     if quant_metadata is not None:
         layers = quant_metadata["layers"]
         for k, v in layers.items():
-            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(bytearray(json.dumps(v).encode('utf-8')), dtype=torch.uint8)

     return state_dict, metadata
diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py
index afa3b6d5d..b62a8892c 100644
--- a/tests/inference/test_workflows.py
+++ b/tests/inference/test_workflows.py
@@ -1,6 +1,7 @@
 import importlib.resources
 import json
 import logging
+import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator

@@ -45,7 +46,7 @@ def _generate_config_params():
         {"fast": set()},
         # {"fast": {PerformanceFeature.Fp16Accumulation}},
         # {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
-        # {"fast": {PerformanceFeature.CublasOps}},
+        {"fast": {PerformanceFeature.CublasOps}},
     ]

     for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):
@@ -57,7 +58,7 @@ def _generate_config_params():
         yield config_update


-@pytest.fixture(scope="function", autouse=False, params=_generate_config_params())
+@pytest.fixture(scope="function", autouse=False, params=_generate_config_params(), ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
 async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]:
     config = default_configuration()
     # this should help things go a little faster
@@ -91,10 +92,15 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     prompt = Prompt.validate(workflow)
     # todo: add all the models we want to test a bit more elegantly
     outputs = {}
+
+    start_time = time.time()
     try:
         outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
+    finally:
+        end_time = time.time()
+        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")

     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
diff --git a/tests/inference/workflows/llava-0.json b/tests/inference/workflows/llava-0.json
index de3ea18c9..09596f743 100644
--- a/tests/inference/workflows/llava-0.json
+++ b/tests/inference/workflows/llava-0.json
@@ -33,7 +33,7 @@
   "4": {
     "inputs": {
       "prompt": "Who is this?",
-      "chat_template": "llava-v1.6-mistral-7b-hf",
+      "chat_template": "default",
       "model": [
         "1",
         0