Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 14:20:49 +08:00

Commit b81a5b15ae: improve tests
Parent: a2e898f091
@@ -23,6 +23,7 @@ from .async_progress_iterable import QueuePromptWithProgress
 from .client_types import V1QueuePromptResponse
 from ..api.components.schema.prompt import PromptDict
 from ..cli_args_types import Configuration
+from ..cli_args import default_configuration
 from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
 from ..component_model.executor_types import ExecutorToClientProgress
 from ..component_model.make_mutable import make_mutable
@@ -176,7 +177,7 @@ class Comfy:
     In order to use this in blocking methods, learn more about asyncio online.
     """

-    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor","ContextVarExecutor"] = None):
+    def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
         self._progress_handler = progress_handler or ServerStub()
         self._owns_executor = executor is None or isinstance(executor, str)
         if self._owns_executor:
@@ -188,6 +189,7 @@ class Comfy:
         else:
             assert not isinstance(executor, str)
             self._executor = executor
+        self._default_configuration = default_configuration()
         self._configuration = configuration
         self._is_running = False
         self._task_count_lock = RLock()
@@ -332,5 +334,9 @@ class Comfy:
         with self._task_count_lock:
             self._task_count -= 1
+
+    def __str__(self):
+        diff = {k: v for k, v in (self._configuration or {}).items() if v != self._default_configuration.get(k)}
+        return f"<Comfy task_count={self.task_count} configuration={diff} executor={self._executor}>"


 EmbeddedComfyClient = Comfy
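
A rough sketch of the idea behind the new __str__ above: only configuration keys that differ from the defaults are printed, which keeps the repr short. The keys used here are made-up placeholders, not the actual Comfy configuration fields.

defaults = {"cpu": False, "highvram": False, "output_directory": "output"}
configuration = {"cpu": True, "output_directory": "output"}

# keep only the keys whose value differs from the default, as in __str__ above
diff = {k: v for k, v in configuration.items() if v != defaults.get(k)}
print(f"<Comfy configuration={diff}>")  # -> <Comfy configuration={'cpu': True}>
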

@@ -1,9 +1,11 @@
 # Original code can be found on: https://github.com/black-forest-labs/flux

-import torch
 from dataclasses import dataclass
-from einops import rearrange, repeat
+
+import torch
 from torch import Tensor, nn
+from einops import rearrange, repeat
+from ..common_dit import pad_to_patch_size
 from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP

 from .layers import (
@@ -16,7 +18,6 @@ from .layers import (
     Modulation,
     RMSNorm
 )
-from .. import common_dit


 @dataclass
@@ -52,8 +53,6 @@ class Flux(nn.Module):

     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
-        # todo: should this be here?
-        self.device = device
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
@@ -147,9 +146,6 @@ class Flux(nn.Module):

         if transformer_options is None:
             transformer_options = {}
-        if y is None:
-            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
-
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -273,10 +269,12 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec_orig)  # (N, T, patch_size ** 2 * out_channels)
         return img

-    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options=None):
+        if transformer_options is None:
+            transformer_options = {}
         bs, c, h, w = x.shape
         patch_size = self.patch_size
-        x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = pad_to_patch_size(x, (patch_size, patch_size))

         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         h_len = ((h + (patch_size // 2)) // patch_size)
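
The switch from `transformer_options={}` to `transformer_options=None` above avoids Python's shared-mutable-default pitfall. A minimal, self-contained sketch of the difference (function names are illustrative only):

def broken(values, options={}):
    # the same dict object is reused across calls, so state leaks between callers
    options.setdefault("count", 0)
    options["count"] += len(values)
    return options

def fixed(values, options=None):
    # a fresh dict is created per call unless the caller passes one in
    if options is None:
        options = {}
    options.setdefault("count", 0)
    options["count"] += len(values)
    return options

print(broken([1, 2]), broken([3]))  # {'count': 3} {'count': 3} -- shared state
print(fixed([1, 2]), fixed([3]))    # {'count': 2} {'count': 1} -- independent
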

@@ -794,7 +794,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
     if quant_config:
         model_config.quant_config = quant_config
-        logger.info("Detected mixed precision quantization")
+        logger.debug("Detected mixed precision quantization")

     if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
         model_config.custom_operations = GGMLOps

@@ -14,6 +14,7 @@ from pathlib import Path
 from typing import List, Optional, Final, Set

 import requests
+import requests_cache
 import tqdm
 from huggingface_hub import dump_environment_info, hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
 from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
@@ -137,36 +138,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
     path = None

     cache_hit = False
-    try:
-        # always retrieve this from the cache if it already exists there
-        path = hf_hub_download(repo_id=known_file.repo_id,
+    hf_hub_download_kwargs = dict(repo_id=known_file.repo_id,
                                   filename=known_file.filename,
                                   repo_type=known_file.repo_type,
                                   revision=known_file.revision,
                                   local_files_only=True,
                                   local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
+                                  token=True,
                                   )

+    with requests_cache.disabled():
+        try:
+            # always retrieve this from the cache if it already exists there
+            path = hf_hub_download(**hf_hub_download_kwargs)
             logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
             cache_hit = True
         except LocalEntryNotFoundError:
             try:
                 logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
-                path = hf_hub_download(repo_id=known_file.repo_id,
-                                       filename=known_file.filename,
-                                       repo_type=known_file.repo_type,
-                                       revision=known_file.revision,
-                                       local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
-                                       )
+                hf_hub_download_kwargs.pop("local_files_only")
+                path = hf_hub_download(**hf_hub_download_kwargs)
             except requests.exceptions.HTTPError as exc_info:
                 if exc_info.response.status_code == 401:
                     raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
             except IOError as exc_info:
                 logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
             except Exception as exc_info:
-                logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
+                logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}. hf_hub_download kwargs={hf_hub_download_kwargs}", exc_info=exc_info)
                 dump_environment_info()
                 for key, value in os.environ.items():
-                    if key.startswith("HF_XET"):
+                    if key.startswith("HF_"):
+                        if key == "HF_TOKEN":
+                            value = "*****"
                         print(f"{key}={value}", file=sys.stderr)

     if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
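
A standalone sketch of the cache-first download pattern this hunk switches to: one kwargs dict, a cache-only attempt with `local_files_only=True`, and a network fallback on `LocalEntryNotFoundError`. The repo and filename below are placeholders; the real code additionally wraps the calls in `requests_cache.disabled()` so the process-wide requests cache does not interfere.

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import LocalEntryNotFoundError

download_kwargs = dict(
    repo_id="some-org/some-model",   # placeholder repository
    filename="model.safetensors",    # placeholder file name
    local_files_only=True,           # first attempt: local cache only, no network
)
try:
    path = hf_hub_download(**download_kwargs)
except LocalEntryNotFoundError:
    download_kwargs.pop("local_files_only")  # second attempt: allow the real download
    path = hf_hub_download(**download_kwargs)
print(path)
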

@@ -1,21 +1,18 @@
 from __future__ import annotations

-import collections
 import dataclasses
 import functools
 from os.path import split
 from pathlib import PurePosixPath
-from typing import Optional, List, Sequence, Union, Iterable, Protocol
+from typing import Optional, List, Sequence, Union, Iterable

 from can_ada import parse, URL # pylint: disable=no-name-in-module
-from typing_extensions import TypedDict, NotRequired, runtime_checkable
+from typing_extensions import TypedDict, NotRequired

 from .component_model.executor_types import ComboOptions
 from .component_model.files import canonicalize_path
-
-


 @dataclasses.dataclass(frozen=True)
 class UrlFile:
     _url: str

@@ -116,7 +116,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             model_options = {}
         if special_tokens is None:
             special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
-        assert layer in self.LAYERS

         if textmodel_json_config is None and "model_name" not in model_options:
             model_options = {**model_options, "model_name": "clip_l"}

@@ -1539,6 +1539,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     if quant_metadata is not None:
         layers = quant_metadata["layers"]
         for k, v in layers.items():
-            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
+            state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(bytearray(json.dumps(v).encode('utf-8')), dtype=torch.uint8)

     return state_dict, metadata
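
The `bytearray(...)` wrapper above addresses `torch.frombuffer` warning about non-writable buffers when it is handed a plain `bytes` object. A small sketch of round-tripping JSON metadata through a uint8 tensor the same way (the payload is illustrative):

import json
import torch

meta = {"layer": "diffusion_model.x", "dtype": "float8_e4m3fn"}  # illustrative payload

# bytes is read-only, so torch.frombuffer would warn about a non-writable buffer;
# copying into a bytearray hands it a writable one instead
buf = bytearray(json.dumps(meta).encode("utf-8"))
t = torch.frombuffer(buf, dtype=torch.uint8)

# decode it back to confirm nothing was lost
assert json.loads(t.numpy().tobytes().decode("utf-8")) == meta
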

@@ -1,6 +1,7 @@
 import importlib.resources
 import json
 import logging
+import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator

@@ -45,7 +46,7 @@ def _generate_config_params():
         {"fast": set()},
         # {"fast": {PerformanceFeature.Fp16Accumulation}},
         # {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
-        # {"fast": {PerformanceFeature.CublasOps}},
+        {"fast": {PerformanceFeature.CublasOps}},
     ]

     for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):
@@ -57,7 +58,7 @@ def _generate_config_params():
         yield config_update


-@pytest.fixture(scope="function", autouse=False, params=_generate_config_params())
+@pytest.fixture(scope="function", autouse=False, params=_generate_config_params(), ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
 async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]:
     config = default_configuration()
     # this should help things go a little faster
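
The `ids=` argument added to the fixture decorator above turns each generated configuration into a readable test id (e.g. `attn=sage,async_offload=True`) instead of an opaque index. A minimal sketch with made-up parameter names:

import pytest

# two made-up configuration dicts standing in for _generate_config_params()
_CONFIGS = [
    {"attn": "sage", "async_offload": True},
    {"attn": "flash", "async_offload": False},
]

@pytest.fixture(params=_CONFIGS, ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
def config(request):
    return request.param

def test_attention_is_set(config):
    # test ids read like test_attention_is_set[attn=sage,async_offload=True]
    assert "attn" in config
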
@@ -91,10 +92,15 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     prompt = Prompt.validate(workflow)
     # todo: add all the models we want to test a bit m2ore elegantly
     outputs = {}
+
+    start_time = time.time()
     try:
         outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
+    finally:
+        end_time = time.time()
+        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")

     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
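
The `finally:` block added above means the elapsed time is logged even when the workflow raises or the test is skipped. A compact sketch of the same pattern, assuming an ordinary `logging` logger:

import logging
import time

logger = logging.getLogger(__name__)

def run_timed(fn, *args, **kwargs):
    start_time = time.time()
    try:
        return fn(*args, **kwargs)
    finally:
        # executes on success, on exceptions, and when pytest.skip() raises inside fn
        end_time = time.time()
        logger.info(f"{getattr(fn, '__name__', fn)} took {end_time - start_time:.4f}s")
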

@@ -33,7 +33,7 @@
   "4": {
     "inputs": {
       "prompt": "Who is this?",
-      "chat_template": "llava-v1.6-mistral-7b-hf",
+      "chat_template": "default",
       "model": [
         "1",
         0