improve tests

doctorpangloss 2025-12-11 17:03:39 -08:00
parent a2e898f091
commit b81a5b15ae
9 changed files with 62 additions and 53 deletions

View File

@@ -23,6 +23,7 @@ from .async_progress_iterable import QueuePromptWithProgress
from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration
from ..cli_args import default_configuration
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable
@@ -176,7 +177,7 @@ class Comfy:
In order to use this in blocking methods, learn more about asyncio online.
"""
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor","ContextVarExecutor"] = None):
def __init__(self, configuration: Optional[Configuration] = None, progress_handler: Optional[ExecutorToClientProgress] = None, max_workers: int = 1, executor: ProcessPoolExecutor | ContextVarExecutor | Literal["ProcessPoolExecutor", "ContextVarExecutor"] = None):
self._progress_handler = progress_handler or ServerStub()
self._owns_executor = executor is None or isinstance(executor, str)
if self._owns_executor:
@@ -188,6 +189,7 @@ class Comfy:
else:
assert not isinstance(executor, str)
self._executor = executor
self._default_configuration = default_configuration()
self._configuration = configuration
self._is_running = False
self._task_count_lock = RLock()
@@ -332,5 +334,9 @@ class Comfy:
with self._task_count_lock:
self._task_count -= 1
def __str__(self):
diff = {k: v for k, v in (self._configuration or {}).items() if v != self._default_configuration.get(k)}
return f"<Comfy task_count={self.task_count} configuration={diff} executor={self._executor}>"
EmbeddedComfyClient = Comfy
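
The new __str__ keeps log output readable by printing only the configuration keys that differ from the defaults captured at construction time. A minimal sketch of the same idea, using illustrative names (make_client_repr, defaults) that are not from the codebase:

def make_client_repr(configuration, defaults, task_count):
    # Keep only the keys whose values differ from the defaults, so the
    # repr stays short even when the full configuration is large.
    diff = {k: v for k, v in (configuration or {}).items() if v != defaults.get(k)}
    return f"<Comfy task_count={task_count} configuration={diff}>"

print(make_client_repr({"fast": {"cublas_ops"}, "cwd": None}, {"fast": set(), "cwd": None}, 0))
# <Comfy task_count=0 configuration={'fast': {'cublas_ops'}}>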

View File

@@ -1,9 +1,11 @@
# Original code can be found on: https://github.com/black-forest-labs/flux
import torch
from dataclasses import dataclass
from einops import rearrange, repeat
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
from ..common_dit import pad_to_patch_size
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from .layers import (
@@ -16,7 +18,6 @@ from .layers import (
Modulation,
RMSNorm
)
from .. import common_dit
@dataclass
@@ -52,8 +53,6 @@ class Flux(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
# todo: should this be here?
self.device = device
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
@@ -147,9 +146,6 @@ class Flux(nn.Module):
if transformer_options is None:
transformer_options = {}
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -273,10 +269,12 @@ class Flux(nn.Module):
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options=None):
if transformer_options is None:
transformer_options = {}
bs, c, h, w = x.shape
patch_size = self.patch_size
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
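
The process_img change replaces the mutable default transformer_options={} with None plus an in-body fallback. A mutable default is created once at function definition time and shared across every call, so state can leak between invocations; a standalone toy example of the pitfall (not repository code):

def bad(options={}):
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options

def good(options=None):
    if options is None:
        options = {}  # a fresh dict on every call
    options.setdefault("calls", 0)
    options["calls"] += 1
    return options

assert bad() == {"calls": 1}
assert bad() == {"calls": 2}   # the same dict leaked into the second call
assert good() == {"calls": 1}
assert good() == {"calls": 1}  # independent state on each call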

View File

@@ -794,7 +794,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
if quant_config:
model_config.quant_config = quant_config
logger.info("Detected mixed precision quantization")
logger.debug("Detected mixed precision quantization")
if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
model_config.custom_operations = GGMLOps

View File

@@ -14,6 +14,7 @@ from pathlib import Path
from typing import List, Optional, Final, Set
import requests
import requests_cache
import tqdm
from huggingface_hub import dump_environment_info, hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
@@ -137,37 +138,39 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
path = None
cache_hit = False
try:
# always retrieve this from the cache if it already exists there
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_files_only=True,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
cache_hit = True
except LocalEntryNotFoundError:
hf_hub_download_kwargs = dict(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_files_only=True,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
token=True,
)
with requests_cache.disabled():
try:
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
except requests.exceptions.HTTPError as exc_info:
if exc_info.response.status_code == 401:
raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
except IOError as exc_info:
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
except Exception as exc_info:
logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
dump_environment_info()
for key, value in os.environ.items():
if key.startswith("HF_XET"):
print(f"{key}={value}", file=sys.stderr)
# always retrieve this from the cache if it already exists there
path = hf_hub_download(**hf_hub_download_kwargs)
logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
cache_hit = True
except LocalEntryNotFoundError:
try:
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
hf_hub_download_kwargs.pop("local_files_only")
path = hf_hub_download(**hf_hub_download_kwargs)
except requests.exceptions.HTTPError as exc_info:
if exc_info.response.status_code == 401:
raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
except IOError as exc_info:
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
except Exception as exc_info:
logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}. hf_hub_download kwargs={hf_hub_download_kwargs}", exc_info=exc_info)
dump_environment_info()
for key, value in os.environ.items():
if key.startswith("HF_"):
if key == "HF_TOKEN":
value = "*****"
print(f"{key}={value}", file=sys.stderr)
if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
tensors = {}
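
The download path is refactored so the keyword arguments for hf_hub_download are built once and reused: the first attempt is cache-only (local_files_only=True), and on a miss the same dict is retried with that key popped. A hedged sketch of the pattern, with a hypothetical fetch callable standing in for hf_hub_download:

def get_cached_or_download(fetch, repo_id, filename):
    # Build the keyword arguments once so both attempts stay in sync.
    kwargs = dict(repo_id=repo_id, filename=filename, local_files_only=True)
    try:
        return fetch(**kwargs), True      # cache hit, no network touched
    except FileNotFoundError:
        kwargs.pop("local_files_only")    # retry with downloads allowed
        return fetch(**kwargs), False

cache = {}
def fake_fetch(repo_id, filename, local_files_only=False):
    key = (repo_id, filename)
    if key not in cache:
        if local_files_only:
            raise FileNotFoundError(key)
        cache[key] = f"/models/{filename}"  # pretend we downloaded it
    return cache[key]

assert get_cached_or_download(fake_fetch, "org/repo", "model.safetensors") == ("/models/model.safetensors", False)
assert get_cached_or_download(fake_fetch, "org/repo", "model.safetensors") == ("/models/model.safetensors", True)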

View File

@@ -1,21 +1,18 @@
from __future__ import annotations
import collections
import dataclasses
import functools
from os.path import split
from pathlib import PurePosixPath
from typing import Optional, List, Sequence, Union, Iterable, Protocol
from typing import Optional, List, Sequence, Union, Iterable
from can_ada import parse, URL # pylint: disable=no-name-in-module
from typing_extensions import TypedDict, NotRequired, runtime_checkable
from typing_extensions import TypedDict, NotRequired
from .component_model.executor_types import ComboOptions
from .component_model.files import canonicalize_path
@dataclasses.dataclass(frozen=True)
class UrlFile:
_url: str

View File

@@ -116,7 +116,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
model_options = {}
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS
if textmodel_json_config is None and "model_name" not in model_options:
model_options = {**model_options, "model_name": "clip_l"}

View File

@@ -1539,6 +1539,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(bytearray(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata
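
torch.frombuffer aliases the buffer it is given without copying, so passing an immutable bytes object makes PyTorch warn that the tensor points at read-only memory. Wrapping the payload in bytearray first, as the new line does, produces an identical tensor without the warning. A minimal demonstration (assuming a recent PyTorch; the exact warning text may vary by version):

import json
import warnings

import torch

payload = json.dumps({"scale": 1.0}).encode("utf-8")

# bytes is immutable, so this aliases read-only memory and PyTorch emits
# a UserWarning ("The given buffer is not writable ...").
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from_bytes = torch.frombuffer(payload, dtype=torch.uint8)

# Copying into a writable bytearray first yields an identical tensor
# without the warning.
from_bytearray = torch.frombuffer(bytearray(payload), dtype=torch.uint8)

assert torch.equal(from_bytes, from_bytearray)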

View File

@@ -1,6 +1,7 @@
import importlib.resources
import json
import logging
import time
from importlib.abc import Traversable
from typing import Any, AsyncGenerator
@@ -45,7 +46,7 @@ def _generate_config_params():
{"fast": set()},
# {"fast": {PerformanceFeature.Fp16Accumulation}},
# {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
# {"fast": {PerformanceFeature.CublasOps}},
{"fast": {PerformanceFeature.CublasOps}},
]
for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):
@@ -57,7 +58,7 @@ def _generate_config_params():
yield config_update
@pytest.fixture(scope="function", autouse=False, params=_generate_config_params())
@pytest.fixture(scope="function", autouse=False, params=_generate_config_params(), ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]:
config = default_configuration()
# this should help things go a little faster
@@ -91,10 +92,15 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
prompt = Prompt.validate(workflow)
# todo: add all the models we want to test a bit m2ore elegantly
outputs = {}
start_time = time.time()
try:
outputs = await client.queue_prompt(prompt)
except TorchAudioNotFoundError:
pytest.skip("requires torchaudio")
finally:
end_time = time.time()
logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")
if any(v.class_type == "SaveImage" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
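
Two test-side tweaks here: the parametrized client fixture gains an ids= callable so each test ID spells out its configuration instead of client0, client1, ..., and queue_prompt is timed in a try/finally so the duration is logged even when the workflow raises. A small self-contained sketch of the ids= pattern (the option names and values below are made up for illustration):

import pytest

_CONFIGS = [
    {"attn": "torch", "fast": False},
    {"attn": "sage", "fast": True},
]

@pytest.fixture(params=_CONFIGS, ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
def config(request):
    return request.param

def test_config_loads(config):
    # pytest -v now prints e.g. test_config_loads[attn=torch,fast=False]
    # instead of test_config_loads[config0].
    assert isinstance(config, dict)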

View File

@@ -33,7 +33,7 @@
"4": {
"inputs": {
"prompt": "Who is this?",
"chat_template": "llava-v1.6-mistral-7b-hf",
"chat_template": "default",
"model": [
"1",
0