mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00

improve tests

This commit is contained in:
parent a2e898f091
commit b81a5b15ae
@@ -23,6 +23,7 @@ from .async_progress_iterable import QueuePromptWithProgress
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
+from ..cli_args import default_configuration
 from ..cmd.folder_paths import init_default_paths  # pylint: disable=import-error
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.make_mutable import make_mutable
@@ -176,7 +177,7 @@ class Comfy:
     In order to use this in blocking methods, learn more about asyncio online.
     """

-    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor","ContextVarExecutor"] = None):
+    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
         self._progress_handler = progress_handler or ServerStub()
         self._owns_executor = executor is None or isinstance(executor, str)
         if self._owns_executor:
@@ -188,6 +189,7 @@ class Comfy:
         else:
             assert not isinstance(executor, str)
             self._executor = executor
+        self._default_configuration = default_configuration()
         self._configuration = configuration
         self._is_running = False
         self._task_count_lock = RLock()
@@ -332,5 +334,9 @@ class Comfy:
         with self._task_count_lock:
             self._task_count -= 1

+    def __str__(self):
+        diff = {k: v for k, v in (self._configuration or {}).items() if v != self._default_configuration.get(k)}
+        return f"<Comfy task_count={self.task_count} configuration={diff} executor={self._executor}>"
+

 EmbeddedComfyClient = Comfy
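The new __str__ reports only the configuration keys that differ from the defaults. A minimal standalone sketch of that dict comprehension, with invented configuration keys and values:

# Invented example values: only the key that differs from its default survives.
defaults = {"cwd": None, "max_queue_size": 0}
configuration = {"cwd": "/tmp/models", "max_queue_size": 0}

diff = {k: v for k, v in configuration.items() if v != defaults.get(k)}
print(f"<Comfy configuration={diff}>")  # <Comfy configuration={'cwd': '/tmp/models'}>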
@@ -1,9 +1,11 @@
 # Original code can be found on: https://github.com/black-forest-labs/flux

+import torch
 from dataclasses import dataclass
+from einops import rearrange, repeat

-import torch
 from torch import Tensor, nn
-from einops import rearrange, repeat
+from ..common_dit import pad_to_patch_size
 from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP

 from .layers import (
@@ -16,7 +18,6 @@ from .layers import (
     Modulation,
     RMSNorm
 )
-from .. import common_dit


 @dataclass
@@ -33,7 +34,7 @@ class FluxParams:
     axes_dim: list
     theta: int
     patch_size: int
-    qkv_bias: bool
+    qkv_bias: bool
     guidance_embed: bool
     txt_ids_dims: list
     global_modulation: bool = False
@@ -52,8 +53,6 @@ class Flux(nn.Module):

     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
-        # todo: should this be here?
-        self.device = device
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
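As an aside, `params = FluxParams(**kwargs)` is plain dataclass keyword construction. A toy sketch with a reduced field set (hidden_size and depth stand in for the full FluxParams):

from dataclasses import dataclass

@dataclass
class Params:
    hidden_size: int
    depth: int

kwargs = {"hidden_size": 3072, "depth": 19}
params = Params(**kwargs)  # an unexpected key would raise TypeError
print(params)  # Params(hidden_size=3072, depth=19)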
@@ -147,9 +146,6 @@ class Flux(nn.Module):

-        if transformer_options is None:
-            transformer_options = {}
         if y is None:
             y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -273,10 +269,12 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec_orig)  # (N, T, patch_size ** 2 * out_channels)
         return img

-    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         bs, c, h, w = x.shape
         patch_size = self.patch_size
-        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = pad_to_patch_size(x, (patch_size, patch_size))

         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         h_len = ((h + (patch_size // 2)) // patch_size)
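The signature change here is the standard fix for Python's mutable-default pitfall: a `={}` default is evaluated once at function definition and shared across every call. A self-contained sketch (not the Flux class itself):

def leaky(x, seen=[]):  # default list is built once and shared by every call
    seen.append(x)
    return seen

print(leaky(1))  # [1]
print(leaky(2))  # [1, 2] -- state leaked from the first call

def fixed(x, seen=None):  # the pattern this hunk adopts
    if seen is None:
        seen = []  # fresh list on every call
    seen.append(x)
    return seen

print(fixed(1))  # [1]
print(fixed(2))  # [2]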
@@ -794,7 +794,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
     if quant_config:
         model_config.quant_config = quant_config
-        logger.info("Detected mixed precision quantization")
+        logger.debug("Detected mixed precision quantization")

     if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
         model_config.custom_operations = GGMLOps
@@ -14,6 +14,7 @@ from pathlib import Path
 from typing import List, Optional, Final, Set

 import requests
 import requests_cache
 import tqdm
 from huggingface_hub import dump_environment_info, hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
 from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
@@ -137,37 +138,39 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
     path = None

     cache_hit = False
+    hf_hub_download_kwargs = dict(repo_id=known_file.repo_id,
+                                  filename=known_file.filename,
+                                  repo_type=known_file.repo_type,
+                                  revision=known_file.revision,
+                                  local_files_only=True,
+                                  local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
+                                  token=True,
+                                  )
     try:
-        # always retrieve this from the cache if it already exists there
-        path = hf_hub_download(repo_id=known_file.repo_id,
-                               filename=known_file.filename,
-                               repo_type=known_file.repo_type,
-                               revision=known_file.revision,
-                               local_files_only=True,
-                               local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
-                               )
-        logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
-        cache_hit = True
-    except LocalEntryNotFoundError:
-        with requests_cache.disabled():
-            try:
-                logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
-                path = hf_hub_download(repo_id=known_file.repo_id,
-                                       filename=known_file.filename,
-                                       repo_type=known_file.repo_type,
-                                       revision=known_file.revision,
-                                       local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
-                                       )
-            except requests.exceptions.HTTPError as exc_info:
-                if exc_info.response.status_code == 401:
-                    raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
-            except IOError as exc_info:
-                logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
-            except Exception as exc_info:
-                logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
-                dump_environment_info()
-                for key, value in os.environ.items():
-                    if key.startswith("HF_XET"):
-                        print(f"{key}={value}", file=sys.stderr)
+        # always retrieve this from the cache if it already exists there
+        path = hf_hub_download(**hf_hub_download_kwargs)
+        logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
+        cache_hit = True
+    except LocalEntryNotFoundError:
+        try:
+            logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
+            hf_hub_download_kwargs.pop("local_files_only")
+            path = hf_hub_download(**hf_hub_download_kwargs)
+        except requests.exceptions.HTTPError as exc_info:
+            if exc_info.response.status_code == 401:
+                raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
+        except IOError as exc_info:
+            logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
+        except Exception as exc_info:
+            logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}. hf_hub_download kwargs={hf_hub_download_kwargs}", exc_info=exc_info)
+            dump_environment_info()
+            for key, value in os.environ.items():
+                if key.startswith("HF_"):
+                    if key == "HF_TOKEN":
+                        value = "*****"
+                    print(f"{key}={value}", file=sys.stderr)

     if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
         tensors = {}
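The rewritten flow builds the download kwargs once, tries the cache with local_files_only=True, and on a miss pops that flag and retries over the network. A condensed sketch with placeholder repo and file names:

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import LocalEntryNotFoundError

# Placeholder repo and filename, for illustration only.
kwargs = dict(repo_id="some-org/some-model",
              filename="model.safetensors",
              local_files_only=True)
try:
    # Cache-only attempt: raises LocalEntryNotFoundError on a miss.
    path = hf_hub_download(**kwargs)
except LocalEntryNotFoundError:
    # Miss: drop the cache-only flag and download over the network.
    kwargs.pop("local_files_only")
    path = hf_hub_download(**kwargs)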
@@ -1,21 +1,18 @@
 from __future__ import annotations

 import collections
 import dataclasses
 import functools
 from os.path import split
 from pathlib import PurePosixPath
-from typing import Optional, List, Sequence, Union, Iterable, Protocol
+from typing import Optional, List, Sequence, Union, Iterable

 from can_ada import parse, URL  # pylint: disable=no-name-in-module
-from typing_extensions import TypedDict, NotRequired, runtime_checkable
+from typing_extensions import TypedDict, NotRequired

 from .component_model.executor_types import ComboOptions
 from .component_model.files import canonicalize_path
-
-
-

 @dataclasses.dataclass(frozen=True)
 class UrlFile:
     _url: str
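For reference, UrlFile keeps its `@dataclasses.dataclass(frozen=True)` decorator; frozen instances are immutable and compare by value, as a quick sketch (example URL invented) shows:

import dataclasses

@dataclasses.dataclass(frozen=True)
class UrlFile:
    _url: str

f = UrlFile("https://example.com/model.safetensors")  # invented URL
# f._url = "other"  # frozen: would raise dataclasses.FrozenInstanceError
print(f == UrlFile("https://example.com/model.safetensors"))  # True, value equality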
@@ -116,7 +116,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         model_options = {}
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
         assert layer in self.LAYERS
-
         if textmodel_json_config is None and "model_name" not in model_options:
             model_options = {**model_options, "model_name": "clip_l"}
@@ -1539,6 +1539,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     if quant_metadata is not None:
         layers = quant_metadata["layers"]
         for k, v in layers.items():
-            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(bytearray(json.dumps(v).encode('utf-8')), dtype=torch.uint8)

     return state_dict, metadata
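The bytearray wrapper matters because torch.frombuffer expects a writable buffer; passing immutable bytes triggers a UserWarning about the buffer not being writable. A small sketch (the quant payload here is an invented example):

import json
import torch

# Invented quant payload; the real one comes from quant_metadata["layers"].
payload = json.dumps({"format": "float8_e4m3fn"}).encode("utf-8")

# bytes is immutable, so this would warn that the buffer is not writable:
# t = torch.frombuffer(payload, dtype=torch.uint8)

# Copying into a bytearray hands torch a writable buffer, avoiding the warning.
t = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
print(t.dtype, t.numel())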
@@ -1,6 +1,7 @@
 import importlib.resources
 import json
 import logging
+import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator
@@ -45,7 +46,7 @@ def _generate_config_params():
         {"fast": set()},
         # {"fast": {PerformanceFeature.Fp16Accumulation}},
        # {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
-        # {"fast": {PerformanceFeature.CublasOps}},
+        {"fast": {PerformanceFeature.CublasOps}},
     ]

     for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):
@@ -57,7 +58,7 @@ def _generate_config_params():
         yield config_update


-@pytest.fixture(scope="function", autouse=False, params=_generate_config_params())
+@pytest.fixture(scope="function", autouse=False, params=_generate_config_params(), ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
 async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]:
     config = default_configuration()
     # this should help things go a little faster
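The added ids= callback turns opaque parametrized fixture labels (client0, client1, ...) into readable ones derived from each config dict. A minimal sketch with invented params:

import pytest

# Without ids=, cases appear as client0/client1; with it, as fast=off / fast=cublas.
@pytest.fixture(params=[{"fast": "off"}, {"fast": "cublas"}],  # invented values
                ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
def client(request):
    return request.param

def test_uses_client(client):
    assert "fast" in client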
@@ -91,10 +92,15 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     prompt = Prompt.validate(workflow)
     # todo: add all the models we want to test a bit more elegantly
+    outputs = {}

+    start_time = time.time()
     try:
         outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
+    finally:
+        end_time = time.time()
+        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")

     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
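The timing wrapper relies on finally: running whether the body returned, raised, or the test skipped (pytest.skip raises internally), so the duration is always logged. A standalone sketch with a stand-in do_work:

import time

def do_work():  # stand-in for client.queue_prompt(...)
    time.sleep(0.1)
    return {}

start_time = time.time()
try:
    outputs = do_work()
finally:
    # Runs on success, exception, or skip alike.
    end_time = time.time()
    print(f"took {end_time - start_time:.4f}s")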
@@ -33,7 +33,7 @@
 "4": {
   "inputs": {
     "prompt": "Who is this?",
-    "chat_template": "llava-v1.6-mistral-7b-hf",
+    "chat_template": "default",
     "model": [
       "1",
       0