improve tests

This commit is contained in:
doctorpangloss 2025-12-11 17:03:39 -08:00
parent a2e898f091
commit b81a5b15ae
9 changed files with 62 additions and 53 deletions

View File

@ -23,6 +23,7 @@ from .async_progress_iterable import QueuePromptWithProgress
from .client_types import V1QueuePromptResponse from .client_types import V1QueuePromptResponse
from ..api.components.schema.prompt import PromptDict from ..api.components.schema.prompt import PromptDict
from ..cli_args_types import Configuration from ..cli_args_types import Configuration
from ..cli_args import default_configuration
from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error from ..cmd.folder_paths import init_default_paths # pylint: disable=import-error
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.make_mutable import make_mutable from ..component_model.make_mutable import make_mutable
@ -188,6 +189,7 @@ class Comfy:
else: else:
assert not isinstance(executor, str) assert not isinstance(executor, str)
self._executor = executor self._executor = executor
self._default_configuration = default_configuration()
self._configuration = configuration self._configuration = configuration
self._is_running = False self._is_running = False
self._task_count_lock = RLock() self._task_count_lock = RLock()
@ -332,5 +334,9 @@ class Comfy:
with self._task_count_lock: with self._task_count_lock:
self._task_count -= 1 self._task_count -= 1
def __str__(self):
diff = {k: v for k, v in (self._configuration or {}).items() if v != self._default_configuration.get(k)}
return f"<Comfy task_count={self.task_count} configuration={diff} executor={self._executor}>"
EmbeddedComfyClient = Comfy EmbeddedComfyClient = Comfy

View File

@ -1,9 +1,11 @@
# Original code can be found on: https://github.com/black-forest-labs/flux # Original code can be found on: https://github.com/black-forest-labs/flux
import torch
from dataclasses import dataclass from dataclasses import dataclass
from einops import rearrange, repeat
import torch
from torch import Tensor, nn from torch import Tensor, nn
from einops import rearrange, repeat
from ..common_dit import pad_to_patch_size
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from .layers import ( from .layers import (
@ -16,7 +18,6 @@ from .layers import (
Modulation, Modulation,
RMSNorm RMSNorm
) )
from .. import common_dit
@dataclass @dataclass
@ -52,8 +53,6 @@ class Flux(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__() super().__init__()
# todo: should this be here?
self.device = device
self.dtype = dtype self.dtype = dtype
params = FluxParams(**kwargs) params = FluxParams(**kwargs)
self.params = params self.params = params
@ -147,9 +146,6 @@ class Flux(nn.Module):
if transformer_options is None: if transformer_options is None:
transformer_options = {} transformer_options = {}
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
@ -273,10 +269,12 @@ class Flux(nn.Module):
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options=None):
if transformer_options is None:
transformer_options = {}
bs, c, h, w = x.shape bs, c, h, w = x.shape
patch_size = self.patch_size patch_size = self.patch_size
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size)) x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size) h_len = ((h + (patch_size // 2)) // patch_size)

View File

@ -794,7 +794,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
quant_config = detect_layer_quantization(state_dict, unet_key_prefix) quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
if quant_config: if quant_config:
model_config.quant_config = quant_config model_config.quant_config = quant_config
logger.info("Detected mixed precision quantization") logger.debug("Detected mixed precision quantization")
if metadata is not None and "format" in metadata and metadata["format"] == "gguf": if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
model_config.custom_operations = GGMLOps model_config.custom_operations = GGMLOps

View File

@ -14,6 +14,7 @@ from pathlib import Path
from typing import List, Optional, Final, Set from typing import List, Optional, Final, Set
import requests import requests
import requests_cache
import tqdm import tqdm
from huggingface_hub import dump_environment_info, hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound from huggingface_hub import dump_environment_info, hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem, CacheNotFound
from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
@ -137,36 +138,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
path = None path = None
cache_hit = False cache_hit = False
try: hf_hub_download_kwargs = dict(repo_id=known_file.repo_id,
# always retrieve this from the cache if it already exists there
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename, filename=known_file.filename,
repo_type=known_file.repo_type, repo_type=known_file.repo_type,
revision=known_file.revision, revision=known_file.revision,
local_files_only=True, local_files_only=True,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None, local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
token=True,
) )
with requests_cache.disabled():
try:
# always retrieve this from the cache if it already exists there
path = hf_hub_download(**hf_hub_download_kwargs)
logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}") logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
cache_hit = True cache_hit = True
except LocalEntryNotFoundError: except LocalEntryNotFoundError:
try: try:
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}") logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
path = hf_hub_download(repo_id=known_file.repo_id, hf_hub_download_kwargs.pop("local_files_only")
filename=known_file.filename, path = hf_hub_download(**hf_hub_download_kwargs)
repo_type=known_file.repo_type,
revision=known_file.revision,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
except requests.exceptions.HTTPError as exc_info: except requests.exceptions.HTTPError as exc_info:
if exc_info.response.status_code == 401: if exc_info.response.status_code == 401:
raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response) raise GatedRepoError(f"{known_file.repo_id}/{known_file.filename}", response=exc_info.response)
except IOError as exc_info: except IOError as exc_info:
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info) logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
except Exception as exc_info: except Exception as exc_info:
logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info) logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}. hf_hub_download kwargs={hf_hub_download_kwargs}", exc_info=exc_info)
dump_environment_info() dump_environment_info()
for key, value in os.environ.items(): for key, value in os.environ.items():
if key.startswith("HF_XET"): if key.startswith("HF_"):
if key == "HF_TOKEN":
value = "*****"
print(f"{key}={value}", file=sys.stderr) print(f"{key}={value}", file=sys.stderr)
if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0: if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:

View File

@ -1,21 +1,18 @@
from __future__ import annotations from __future__ import annotations
import collections
import dataclasses import dataclasses
import functools import functools
from os.path import split from os.path import split
from pathlib import PurePosixPath from pathlib import PurePosixPath
from typing import Optional, List, Sequence, Union, Iterable, Protocol from typing import Optional, List, Sequence, Union, Iterable
from can_ada import parse, URL # pylint: disable=no-name-in-module from can_ada import parse, URL # pylint: disable=no-name-in-module
from typing_extensions import TypedDict, NotRequired, runtime_checkable from typing_extensions import TypedDict, NotRequired
from .component_model.executor_types import ComboOptions from .component_model.executor_types import ComboOptions
from .component_model.files import canonicalize_path from .component_model.files import canonicalize_path
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class UrlFile: class UrlFile:
_url: str _url: str

View File

@ -116,7 +116,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
model_options = {} model_options = {}
if special_tokens is None: if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407} special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS
if textmodel_json_config is None and "model_name" not in model_options: if textmodel_json_config is None and "model_name" not in model_options:
model_options = {**model_options, "model_name": "clip_l"} model_options = {**model_options, "model_name": "clip_l"}

View File

@ -1539,6 +1539,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None: if quant_metadata is not None:
layers = quant_metadata["layers"] layers = quant_metadata["layers"]
for k, v in layers.items(): for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(bytearray(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata return state_dict, metadata

View File

@ -1,6 +1,7 @@
import importlib.resources import importlib.resources
import json import json
import logging import logging
import time
from importlib.abc import Traversable from importlib.abc import Traversable
from typing import Any, AsyncGenerator from typing import Any, AsyncGenerator
@ -45,7 +46,7 @@ def _generate_config_params():
{"fast": set()}, {"fast": set()},
# {"fast": {PerformanceFeature.Fp16Accumulation}}, # {"fast": {PerformanceFeature.Fp16Accumulation}},
# {"fast": {PerformanceFeature.Fp8MatrixMultiplication}}, # {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
# {"fast": {PerformanceFeature.CublasOps}}, {"fast": {PerformanceFeature.CublasOps}},
] ]
for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options): for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):
@ -57,7 +58,7 @@ def _generate_config_params():
yield config_update yield config_update
@pytest.fixture(scope="function", autouse=False, params=_generate_config_params()) @pytest.fixture(scope="function", autouse=False, params=_generate_config_params(), ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]: async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]:
config = default_configuration() config = default_configuration()
# this should help things go a little faster # this should help things go a little faster
@ -91,10 +92,15 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
prompt = Prompt.validate(workflow) prompt = Prompt.validate(workflow)
# todo: add all the models we want to test a bit m2ore elegantly # todo: add all the models we want to test a bit m2ore elegantly
outputs = {} outputs = {}
start_time = time.time()
try: try:
outputs = await client.queue_prompt(prompt) outputs = await client.queue_prompt(prompt)
except TorchAudioNotFoundError: except TorchAudioNotFoundError:
pytest.skip("requires torchaudio") pytest.skip("requires torchaudio")
finally:
end_time = time.time()
logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")
if any(v.class_type == "SaveImage" for v in prompt.values()): if any(v.class_type == "SaveImage" for v in prompt.values()):
save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage") save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")

View File

@ -33,7 +33,7 @@
"4": { "4": {
"inputs": { "inputs": {
"prompt": "Who is this?", "prompt": "Who is this?",
"chat_template": "llava-v1.6-mistral-7b-hf", "chat_template": "default",
"model": [ "model": [
"1", "1",
0 0