Qwen Image with sage attention workarounds

doctorpangloss 2025-08-07 17:29:23 -07:00
parent b72e5ff448
commit cd17b42664
18 changed files with 435 additions and 44 deletions

View File

@ -18,7 +18,6 @@ from .. import options
from ..app import logger
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
os.environ["BITSANDBYTES_NOWELCOME"] = "1"

View File

@ -0,0 +1,119 @@
from __future__ import annotations
import logging
import os
import time
from concurrent.futures import Future
from pathlib import Path
from typing import Optional
import filelock
import huggingface_hub
from huggingface_hub import hf_hub_download
from huggingface_hub import logging as hf_logging
hf_logging.set_verbosity_debug()
from pebble import ThreadPool
from .tqdm_watcher import TqdmWatcher
_VAR = "HF_HUB_ENABLE_HF_TRANSFER"
_XET_VAR = "HF_XET_HIGH_PERFORMANCE"
os.environ[_VAR] = "True"
os.environ["HF_HUB_DISABLE_XET"] = "1"
# os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
logger = logging.getLogger(__name__)
logger.debug("Xet was disabled since it is currently not reliable")
def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: dict[str, str] = None, **kwargs):
for k, v in hf_env.items():
os.environ[k] = v
if disable_fast:
if _VAR == _XET_VAR:
os.environ["HF_HUB_DISABLE_XET"] = "1"
else:
os.environ[_VAR] = "False"
return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
def hf_hub_download_with_retries(repo_id: str, filename: str, watcher: Optional[TqdmWatcher] = None, retries=2, stall_timeout=10, **kwargs):
"""
Wraps hf_hub_download with stall detection and retries using a TqdmWatcher.
Includes a monkey-patch for filelock to release locks from stalled downloads.
"""
if watcher is None:
logger.warning(f"called _hf_hub_download_with_retries without progress to watch")
return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
xet_available = huggingface_hub.file_download.is_xet_available()
hf_hub_disable_xet_prev_value = os.getenv("HF_HUB_DISABLE_XET")
disable_fast = hf_hub_disable_xet_prev_value is not None
instantiated_locks: list[filelock.FileLock] = []
original_filelock_init = filelock.FileLock.__init__
def new_filelock_init(self, *args, **kwargs):
"""A wrapper around FileLock.__init__ to capture lock instances."""
original_filelock_init(self, *args, **kwargs)
instantiated_locks.append(self)
filelock.FileLock.__init__ = new_filelock_init
try:
with ThreadPool(max_workers=retries + 1) as executor:
for attempt in range(retries):
watcher.tick()
hf_env = {k: v for k, v in os.environ.items() if k.upper().startswith("HF_")}
if len(instantiated_locks) > 0:
logger.debug(f"Attempting to unlock {len(instantiated_locks)} captured file locks.")
for lock in instantiated_locks:
path = lock.lock_file
if lock.is_locked:
lock.release(force=True)
else:
# something else went wrong
try:
lock._release()
except (AttributeError, TypeError):
pass
try:
Path(path).unlink(missing_ok=True)
except OSError:
# todo: some other process evidently still holds this lock file
pass
logger.debug(f"Released stalled lock: {lock.lock_file}")
instantiated_locks.clear()
future: Future[str] = executor.submit(hf_hub_download_with_disable_fast, repo_id=repo_id, filename=filename, disable_fast=disable_fast, hf_env=hf_env, **kwargs)
try:
while not future.done():
if time.monotonic() - watcher.last_update_time > stall_timeout:
msg = f"Download of '{repo_id}/{filename}' stalled for >{stall_timeout}s. Retrying... (Attempt {attempt + 1}/{retries})"
if xet_available:
logger.warning(f"{msg}. Disabling xet for our retry.")
disable_fast = True
else:
logger.warning(msg)
future.cancel() # Cancel the stalled future
break
time.sleep(0.5)
if future.done() and not future.cancelled():
return future.result()
except Exception as e:
logger.error(f"Exception during download attempt {attempt + 1}: {e}", exc_info=True)
raise RuntimeError(f"Failed to download '{repo_id}/{filename}' after {retries} attempts.")
finally:
filelock.FileLock.__init__ = original_filelock_init
if hf_hub_disable_xet_prev_value is not None:
os.environ["HF_HUB_DISABLE_XET"] = hf_hub_disable_xet_prev_value

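For context, a minimal sketch of the intended call pattern (the repo id and filename below are placeholders, and the comfy_tqdm import path is assumed from the utils hunk later in this commit): comfy_tqdm() yields the TqdmWatcher that the patched tqdm instances tick, and hf_hub_download_with_retries polls it to detect stalled transfers.
# Sketch only: repo_id/filename are placeholders; comfy_tqdm is assumed
# to live in comfy.utils as suggested by the utils hunk in this commit.
from comfy.utils import comfy_tqdm
from comfy.component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
with comfy_tqdm() as watcher:
    # the patched tqdm calls watcher.tick() as bytes arrive, so a silent
    # transfer is detected after stall_timeout seconds and retried
    path = hf_hub_download_with_retries(
        repo_id="Qwen/Qwen-Image",              # placeholder
        filename="qwen_image_vae.safetensors",  # placeholder
        watcher=watcher,
        stall_timeout=10,
        retries=2,
    )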
View File

@ -0,0 +1,20 @@
from __future__ import annotations
import time
class TqdmWatcher:
"""An object to track the progress of a tqdm instance."""
def __init__(self):
# We use a list to store the time, making it mutable across scopes.
self._last_update = [time.monotonic()]
def tick(self):
"""Signals that progress has been made by updating the timestamp."""
self._last_update[0] = time.monotonic()
@property
def last_update_time(self) -> float:
"""Gets the time of the last recorded progress update."""
return self._last_update[0]

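A brief illustration of how the watcher is meant to be polled (the helper name is_stalled and the 10-second timeout are illustrative, not part of this commit); the retry wrapper above applies the same check: a download counts as stalled once time.monotonic() minus last_update_time exceeds the stall timeout.
import time
from comfy.component_model.tqdm_watcher import TqdmWatcher
watcher = TqdmWatcher()
def is_stalled(w: TqdmWatcher, timeout: float = 10.0) -> bool:
    # a worker thread calls w.tick() whenever progress is observed;
    # monotonic time avoids wall-clock adjustments skewing the check
    return time.monotonic() - w.last_update_time > timeout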
View File

@ -36,8 +36,8 @@ REARRANGE_THRESHOLD = 512
MAX_TENSOR_NAME_LENGTH = 127
MAX_TENSOR_DIMS = 4
TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "qwen_image"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
class ModelTemplate:
@ -739,9 +739,9 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
except Exception as e:
raise ValueError(f"This model is not currently supported - ({e})")
elif arch_str not in TXT_ARCH_LIST and is_text_model:
raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
logger.warning(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
elif arch_str not in IMG_ARCH_LIST and not is_text_model:
raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")
logger.warning(f"Unexpected architecture type in GGUF file: {arch_str!r}")
if compat:
logger.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")
@ -903,7 +903,7 @@ def gguf_clip_loader(path):
logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
sd = sd_map_replace(sd, T5_SD_MAP)
elif arch in {"llama"}:
elif arch in {"llama", "qwen2vl"}:
# TODO: pass model_options["vocab_size"] to loader somehow
temb_key = "token_embd.weight"
if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
@ -911,7 +911,8 @@ def gguf_clip_loader(path):
logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
sd = sd_map_replace(sd, LLAMA_SD_MAP)
sd = llama_permute(sd, 32, 8) # L3
if arch == "llama":
sd = llama_permute(sd, 32, 8) # L3
else:
pass
return sd

View File

@ -605,11 +605,13 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
return out
optimized_attention = attention_basic
optimized_attention = attention_pytorch
optimized_attention_no_sage = attention_pytorch
if model_management.sage_attention_enabled():
logger.debug("Using sage attention")
optimized_attention = attention_sage
optimized_attention_no_sage = attention_pytorch
elif model_management.xformers_enabled():
logger.debug("Using xformers attention")
optimized_attention = attention_xformers
@ -628,6 +630,7 @@ else:
optimized_attention = attention_sub_quad
optimized_attention_masked = optimized_attention
optimized_attention_no_sage_masked = optimized_attention_no_sage
def optimized_attention_for_device(device, mask=False, small_input=False):

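The shape of the workaround named in the commit title, restated as a sketch (the names come from the surrounding attention module and model_management, so this is not standalone code): sage attention only ever replaces optimized_attention, while optimized_attention_no_sage stays pinned to the PyTorch SDPA path for models that misbehave under sage, such as Qwen Image below.
# Sketch of the fallback scheme; attention_pytorch, attention_sage and
# model_management are names already defined in this module.
optimized_attention = attention_pytorch
optimized_attention_no_sage = attention_pytorch
if model_management.sage_attention_enabled():
    optimized_attention = attention_sage
    # optimized_attention_no_sage intentionally stays on attention_pytorch,
    # so modules such as the Qwen Image model can alias it to opt out of sage.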
View File

@ -8,7 +8,7 @@ from typing import Optional, Tuple
from ..common_dit import pad_to_patch_size
from ..flux.layers import EmbedND
from ..lightricks.model import TimestepEmbedding, Timesteps
from ..modules.attention import optimized_attention_masked
from ..modules.attention import optimized_attention_no_sage_masked as optimized_attention_masked
class GELU(nn.Module):

View File

@ -5,6 +5,7 @@ import logging
import operator
import os
import shutil
import sys
from collections.abc import Sequence, MutableSequence
from functools import reduce
from itertools import chain
@ -12,16 +13,15 @@ from os.path import join
from pathlib import Path
from typing import List, Optional, Final, Set
# enable better transfer
os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
from .component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
from huggingface_hub.file_download import are_symlinks_supported
from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
from huggingface_hub import dump_environment_info
from .cli_args import args
from .cmd import folder_paths
@ -92,7 +92,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
if known_file is None:
logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}")
return path
with comfy_tqdm():
with comfy_tqdm() as watcher:
if isinstance(known_file, HuggingFile):
if known_file.save_with_filename is not None:
linked_filename = known_file.save_with_filename
@ -139,14 +139,21 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
except LocalEntryNotFoundError:
try:
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
path = hf_hub_download_with_retries(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
watcher=watcher,
local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
)
except IOError as exc_info:
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
except Exception as exc_info:
logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
dump_environment_info()
for key, value in os.environ.items():
if key.startswith("HF_XET"):
print(f"{key}={value}", file=sys.stderr)
if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
tensors = {}

View File

@ -8,6 +8,7 @@ import os
import re
import traceback
import zipfile
import logging
from pathlib import Path
try:
@ -26,6 +27,7 @@ from .component_model import files
from .component_model.files import get_path_as_dict, get_package_as_path
from .text_encoders.spiece_tokenizer import SPieceTokenizer
logger = logging.getLogger(__name__)
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@ -119,6 +121,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None and "model_name" not in model_options:
model_options = {**model_options, "model_name": "clip_l"}
if "model_name" in model_options and "clip" not in model_options["model_name"].lower() and textmodel_json_config is None:
logger.warning(f"Text encoder {model_options["model_name"]} provided a None textmodel_json_config, when it should have been an empty dict")
textmodel_json_config = {}
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})

View File

@ -30,7 +30,7 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options=None, special_tokens=None):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options=None, special_tokens=None, textmodel_json_config=None):
if special_tokens is None:
special_tokens = {"start": 128000, "pad": 128258}
if model_options is None:
@ -40,7 +40,7 @@ class LLAMAModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
textmodel_json_config = {}
textmodel_json_config = textmodel_json_config or {}
vocab_size = model_options.get("vocab_size", None)
if vocab_size is not None:
textmodel_json_config["vocab_size"] = vocab_size

View File

@ -1,8 +1,7 @@
from dataclasses import dataclass
from typing import Optional, Any
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
from ..ldm.common_dit import rms_norm
from ..ldm.modules.attention import optimized_attention_for_device
@ -25,6 +24,7 @@ class Llama2Config:
mlp_activation = "silu"
qkv_bias = False
@dataclass
class Qwen25_3BConfig:
vocab_size: int = 151936
@ -60,6 +60,7 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu"
qkv_bias = True
@dataclass
class Gemma2_2B_Config:
vocab_size: int = 256000
@ -363,6 +364,7 @@ class Llama2(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -372,6 +374,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -381,6 +384,7 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -25,7 +25,8 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
textmodel_json_config = textmodel_json_config or {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -29,9 +29,10 @@ class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
textmodel_json_config = textmodel_json_config or {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

View File

@ -31,10 +31,11 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
textmodel_json_config = textmodel_json_config or {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class QwenImageTEModel(sd1_clip.SD1ClipModel):
@ -82,4 +83,4 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_
return QwenImageTEModel_

View File

@ -31,7 +31,7 @@ import warnings
from contextlib import contextmanager
from pathlib import Path
from pickle import UnpicklingError
from typing import Optional, Any, Literal
from typing import Optional, Any, Literal, Generator
import numpy as np
import safetensors.torch
@ -48,7 +48,8 @@ from .cli_args import args
from .component_model import files
from .component_model.deprecation import _deprecate_method
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .execution_context import current_execution_context, ExecutionContext
from .component_model.tqdm_watcher import TqdmWatcher
from .execution_context import current_execution_context
from .gguf import gguf_sd_loader
MMAP_TORCH_FILES = args.mmap_torch_files
@ -115,9 +116,9 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
if len(e.args) > 0:
message = e.args[0]
if "HeaderTooLarge" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
if "MetadataIncompleteBuffer" in message:
raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
raise e
elif ckpt.lower().endswith("index.json"):
# from accelerate
@ -1191,42 +1192,49 @@ def get_project_root() -> str:
@contextmanager
def comfy_tqdm():
def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
"""
Monky patches child calls to tqdm and sends the progress to the UI
:return:
Monkey patches child calls to tqdm, sends progress to the UI,
and yields a watcher object for stall detection.
"""
_original_init = tqdm.__init__
_original_call = tqdm.__call__
_original_update = tqdm.update
# Create the watcher instance that the patched methods will update
# and that will be yielded to the caller.
watcher = TqdmWatcher()
context = contextvars.copy_context()
try:
# These inner functions are closures; they capture the `watcher` variable
# from the enclosing scope.
def __init(self, *args, **kwargs):
context.run(lambda: _original_init(self, *args, **kwargs))
self._progress_bar = context.run(lambda: ProgressBar(self.total))
watcher.tick() # Signal progress on initialization
def __update(self, n=1):
assert self._progress_bar is not None
context.run(lambda: _original_update(self, n))
context.run(lambda: self._progress_bar.update(n))
watcher.tick() # Signal progress on update
def __call(self, *args, **kwargs):
# When TQDM is called to wrap an iterable, ensure the instance is created
# with the captured context
instance = context.run(lambda: _original_call(self, *args, **kwargs))
return instance
tqdm.__init__ = __init
tqdm.__call__ = __call
tqdm.update = __update
# todo: modify the tqdm class here so the captured context is copied into callables passed to tqdm
yield
yield watcher
finally:
# Restore original tqdm
tqdm.__init__ = _original_init
tqdm.__call__ = _original_call
tqdm.update = _original_update
# todo: undo the context copying here
@contextmanager

View File

@ -11,12 +11,13 @@ from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UP
from comfy.model_management import load_models_gpu
from comfy.model_management_types import ModelManageable
logger = logging.getLogger(__name__)
try:
from spandrel_extra_arches import EXTRA_REGISTRY # pylint: disable=import-error
from spandrel import MAIN_REGISTRY
MAIN_REGISTRY.add(*EXTRA_REGISTRY)
logging.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
logger.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
except:
pass

View File

@ -0,0 +1,141 @@
{
"3": {
"inputs": {
"seed": 918331849236269,
"steps": 1,
"cfg": 1,
"sampler_name": "res_multistep",
"scheduler": "normal",
"denoise": 1,
"model": [
"66",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"58",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"6": {
"inputs": {
"text": "cute anime girl with massive fennec ears and a big fluffy fox tail with long wavy blonde hair between eyes and large blue eyes blonde colored eyelashes chubby wearing oversized clothes summer uniform long blue maxi skirt muddy clothes happy sitting on the side of the road in a run down dark gritty cyberpunk city with neon and a crumbling skyscraper in the rain at night while dipping her feet in a river of water she is holding a sign that says \"ComfyUI is the best\" written in cursive",
"clip": [
"38",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"7": {
"inputs": {
"text": " ",
"clip": [
"38",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"39",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"37": {
"inputs": {
"unet_name": "qwen-image-Q4_K_M.gguf",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"38": {
"inputs": {
"clip_name": "qwen_2.5_vl_7b.safetensors",
"type": "qwen_image",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "Load CLIP"
}
},
"39": {
"inputs": {
"vae_name": "qwen_image_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"58": {
"inputs": {
"width": 1328,
"height": 1328,
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"60": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"66": {
"inputs": {
"shift": 3.1000000000000005,
"model": [
"37",
0
]
},
"class_type": "ModelSamplingAuraFlow",
"_meta": {
"title": "ModelSamplingAuraFlow"
}
}
}

View File

@ -0,0 +1,79 @@
import threading
import time
import pytest
from comfy.component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
from comfy.component_model.tqdm_watcher import TqdmWatcher
download_method_name = "comfy.component_model.hf_hub_download_with_disable_xet.hf_hub_download_with_disable_fast"
def mock_stalled_download(*args, **kwargs):
"""A mock for hf_hub_download that simulates a stall by sleeping indefinitely."""
time.sleep(10)
return "this_path_should_never_be_returned"
def test_download_stalls_and_fails(monkeypatch):
"""
Verify that a stalled download triggers retries and eventually fails with a RuntimeError.
"""
monkeypatch.setattr(download_method_name, mock_stalled_download)
watcher = TqdmWatcher()
repo_id = "test/repo-stall"
filename = "stalled_file.safetensors"
with pytest.raises(RuntimeError) as excinfo:
hf_hub_download_with_retries(
repo_id=repo_id,
filename=filename,
watcher=watcher,
stall_timeout=0.2,
retries=2,
)
assert f"Failed to download '{repo_id}/{filename}' after 2 attempts" in str(excinfo.value)
def mock_successful_slow_download(*args, **kwargs):
"""A mock for a download that is slow but not stalled."""
time.sleep(1)
return "expected/successful/path"
def _keep_watcher_alive(watcher: TqdmWatcher, stop_event: threading.Event):
"""Helper function to run in a thread and periodically tick the watcher."""
while not stop_event.is_set():
watcher.tick()
time.sleep(0.1)
def test_download_progresses_and_succeeds(monkeypatch):
"""
Verify that a download with periodic progress updates completes successfully.
"""
monkeypatch.setattr(download_method_name, mock_successful_slow_download)
watcher = TqdmWatcher()
stop_event = threading.Event()
ticker_thread = threading.Thread(
target=_keep_watcher_alive,
args=(watcher, stop_event),
daemon=True
)
ticker_thread.start()
try:
result = hf_hub_download_with_retries(
repo_id="test/repo-success",
filename="good_file.safetensors",
stall_timeout=0.3,
watcher=watcher
)
assert result == "expected/successful/path"
finally:
stop_event.set()
ticker_thread.join(timeout=1)

View File

@ -74,7 +74,7 @@ def mock_openai_client():
instance.images.generate = Mock()
yield instance
@pytest.mark.skip("broken transformers")
def test_transformers_loader(has_gpu):
if not has_gpu:
pytest.skip("requires GPU")