Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 06:10:50 +08:00)

Qwen Image with sage attention workarounds
parent b72e5ff448
commit cd17b42664
@@ -18,7 +18,6 @@ from .. import options
 from ..app import logger

 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
-os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
 os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
 os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"
comfy/component_model/hf_hub_download_with_disable_xet.py (new file, 119 lines)

from __future__ import annotations

import logging
import os
import time
from concurrent.futures import Future
from pathlib import Path
from typing import Optional

import filelock
import huggingface_hub
from huggingface_hub import hf_hub_download
from huggingface_hub import logging as hf_logging

hf_logging.set_verbosity_debug()
from pebble import ThreadPool

from .tqdm_watcher import TqdmWatcher

_VAR = "HF_HUB_ENABLE_HF_TRANSFER"
_XET_VAR = "HF_XET_HIGH_PERFORMANCE"

os.environ[_VAR] = "True"

os.environ["HF_HUB_DISABLE_XET"] = "1"
# os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"

logger = logging.getLogger(__name__)
logger.debug("Xet was disabled since it is currently not reliable")


def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: dict[str, str] = None, **kwargs):
    for k, v in hf_env.items():
        os.environ[k] = v
    if disable_fast:
        if _VAR == _XET_VAR:
            os.environ["HF_HUB_DISABLE_XET"] = "1"
        else:
            os.environ[_VAR] = "False"
    return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)


def hf_hub_download_with_retries(repo_id: str, filename: str, watcher: Optional[TqdmWatcher] = None, retries=2, stall_timeout=10, **kwargs):
    """
    Wraps hf_hub_download with stall detection and retries using a TqdmWatcher.
    Includes a monkey-patch for filelock to release locks from stalled downloads.
    """
    if watcher is None:
        logger.warning(f"called _hf_hub_download_with_retries without progress to watch")
        return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)

    xet_available = huggingface_hub.file_download.is_xet_available()
    hf_hub_disable_xet_prev_value = os.getenv("HF_HUB_DISABLE_XET")
    disable_fast = hf_hub_disable_xet_prev_value is not None

    instantiated_locks: list[filelock.FileLock] = []
    original_filelock_init = filelock.FileLock.__init__

    def new_filelock_init(self, *args, **kwargs):
        """A wrapper around FileLock.__init__ to capture lock instances."""
        original_filelock_init(self, *args, **kwargs)
        instantiated_locks.append(self)

    filelock.FileLock.__init__ = new_filelock_init

    try:
        with ThreadPool(max_workers=retries + 1) as executor:
            for attempt in range(retries):
                watcher.tick()
                hf_env = {k: v for k, v in os.environ.items() if k.upper().startswith("HF_")}

                if len(instantiated_locks) > 0:
                    logger.debug(f"Attempting to unlock {len(instantiated_locks)} captured file locks.")
                    for lock in instantiated_locks:
                        path = lock.lock_file
                        if lock.is_locked:
                            lock.release(force=True)
                        else:
                            # something else went wrong
                            try:
                                lock._release()
                            except (AttributeError, TypeError):
                                pass
                            try:
                                Path(path).unlink(missing_ok=True)
                            except OSError:
                                # todo: obviously the process is holding this lock
                                pass
                        logger.debug(f"Released stalled lock: {lock.lock_file}")
                    instantiated_locks.clear()
                future: Future[str] = executor.submit(hf_hub_download_with_disable_fast, repo_id=repo_id, filename=filename, disable_fast=disable_fast, hf_env=hf_env, **kwargs)

                try:
                    while not future.done():
                        if time.monotonic() - watcher.last_update_time > stall_timeout:
                            msg = f"Download of '{repo_id}/{filename}' stalled for >{stall_timeout}s. Retrying... (Attempt {attempt + 1}/{retries})"
                            if xet_available:
                                logger.warning(f"{msg}. Disabling xet for our retry.")
                                disable_fast = True
                            else:
                                logger.warning(msg)

                            future.cancel()  # Cancel the stalled future
                            break

                        time.sleep(0.5)

                    if future.done() and not future.cancelled():
                        return future.result()

                except Exception as e:
                    logger.error(f"Exception during download attempt {attempt + 1}: {e}", exc_info=True)

            raise RuntimeError(f"Failed to download '{repo_id}/{filename}' after {retries} attempts.")
    finally:
        filelock.FileLock.__init__ = original_filelock_init

        if hf_hub_disable_xet_prev_value is not None:
            os.environ["HF_HUB_DISABLE_XET"] = hf_hub_disable_xet_prev_value
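For orientation, a minimal usage sketch of the new helper, not part of the commit itself: it assumes comfy_tqdm lives in comfy.utils and yields the watcher (as in the hunks further below), and the repo id and filename are placeholders.

from comfy.component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
from comfy.utils import comfy_tqdm  # assumed import path for comfy_tqdm

# comfy_tqdm() patches tqdm so every progress update ticks the watcher;
# hf_hub_download_with_retries re-submits the download when no tick arrives
# within stall_timeout seconds, disabling xet/hf_transfer on the retry.
with comfy_tqdm() as watcher:
    path = hf_hub_download_with_retries(
        repo_id="some-org/some-model",  # placeholder repo
        filename="model.safetensors",  # placeholder filename
        watcher=watcher,
        retries=2,
        stall_timeout=10,
    )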
comfy/component_model/tqdm_watcher.py (new file, 20 lines)

from __future__ import annotations

import time


class TqdmWatcher:
    """An object to track the progress of a tqdm instance."""

    def __init__(self):
        # We use a list to store the time, making it mutable across scopes.
        self._last_update = [time.monotonic()]

    def tick(self):
        """Signals that progress has been made by updating the timestamp."""
        self._last_update[0] = time.monotonic()

    @property
    def last_update_time(self) -> float:
        """Gets the time of the last recorded progress update."""
        return self._last_update[0]
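TqdmWatcher is just a monotonic timestamp holder, so a stall check reduces to comparing last_update_time against a threshold. A small illustrative sketch, not from the commit; the 10-second threshold is an arbitrary example, not a project default.

import time

from comfy.component_model.tqdm_watcher import TqdmWatcher

watcher = TqdmWatcher()
STALL_TIMEOUT = 10.0  # example threshold in seconds

# a progress callback somewhere calls watcher.tick() whenever work advances
watcher.tick()

# a supervising loop treats a silent watcher as a stalled operation
if time.monotonic() - watcher.last_update_time > STALL_TIMEOUT:
    print(f"no progress for {STALL_TIMEOUT}s, consider cancelling and retrying")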
@@ -36,8 +36,8 @@ REARRANGE_THRESHOLD = 512
 MAX_TENSOR_NAME_LENGTH = 127
 MAX_TENSOR_DIMS = 4
 TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
-IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}
+IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "qwen_image"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}


 class ModelTemplate:
@@ -739,9 +739,9 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=Fal
         except Exception as e:
             raise ValueError(f"This model is not currently supported - ({e})")
     elif arch_str not in TXT_ARCH_LIST and is_text_model:
-        raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
+        logger.warning(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
     elif arch_str not in IMG_ARCH_LIST and not is_text_model:
-        raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")
+        logger.warning(f"Unexpected architecture type in GGUF file: {arch_str!r}")

     if compat:
         logger.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")
@@ -903,7 +903,7 @@ def gguf_clip_loader(path):
             logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
             sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
         sd = sd_map_replace(sd, T5_SD_MAP)
-    elif arch in {"llama"}:
+    elif arch in {"llama", "qwen2vl"}:
         # TODO: pass model_options["vocab_size"] to loader somehow
         temb_key = "token_embd.weight"
         if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
@@ -911,7 +911,8 @@ def gguf_clip_loader(path):
             logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
             sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
         sd = sd_map_replace(sd, LLAMA_SD_MAP)
-        sd = llama_permute(sd, 32, 8)  # L3
+        if arch == "llama":
+            sd = llama_permute(sd, 32, 8)  # L3
+        else:
+            pass
     return sd

@@ -605,11 +605,13 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     return out


-optimized_attention = attention_basic
+optimized_attention = attention_pytorch
+optimized_attention_no_sage = attention_pytorch

 if model_management.sage_attention_enabled():
     logger.debug("Using sage attention")
     optimized_attention = attention_sage
+    optimized_attention_no_sage = attention_pytorch
 elif model_management.xformers_enabled():
     logger.debug("Using xformers attention")
     optimized_attention = attention_xformers
@@ -628,6 +630,7 @@ else:
     optimized_attention = attention_sub_quad

 optimized_attention_masked = optimized_attention
+optimized_attention_no_sage_masked = optimized_attention_no_sage


 def optimized_attention_for_device(device, mask=False, small_input=False):
@@ -8,7 +8,7 @@ from typing import Optional, Tuple
 from ..common_dit import pad_to_patch_size
 from ..flux.layers import EmbedND
 from ..lightricks.model import TimestepEmbedding, Timesteps
-from ..modules.attention import optimized_attention_masked
+from ..modules.attention import optimized_attention_no_sage_masked as optimized_attention_masked


 class GELU(nn.Module):
@@ -5,6 +5,7 @@ import logging
 import operator
 import os
 import shutil
 import sys
 from collections.abc import Sequence, MutableSequence
 from functools import reduce
 from itertools import chain
@@ -12,16 +13,15 @@ from os.path import join
 from pathlib import Path
 from typing import List, Optional, Final, Set

-# enable better transfer
-os.environ["HF_XET_HIGH_PERFORMANCE"] = "True"
+from .component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries

 import tqdm
 from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
 from huggingface_hub.file_download import are_symlinks_supported
 from huggingface_hub.utils import GatedRepoError, LocalEntryNotFoundError
 from requests import Session
 from safetensors import safe_open
 from safetensors.torch import save_file
+from huggingface_hub import dump_environment_info

 from .cli_args import args
 from .cmd import folder_paths
@@ -92,7 +92,7 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
     if known_file is None:
         logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}")
         return path
-    with comfy_tqdm():
+    with comfy_tqdm() as watcher:
         if isinstance(known_file, HuggingFile):
             if known_file.save_with_filename is not None:
                 linked_filename = known_file.save_with_filename
@@ -139,14 +139,21 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
             except LocalEntryNotFoundError:
                 try:
                     logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
-                    path = hf_hub_download(repo_id=known_file.repo_id,
-                                           filename=known_file.filename,
-                                           repo_type=known_file.repo_type,
-                                           revision=known_file.revision,
-                                           local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
-                                           )
+                    path = hf_hub_download_with_retries(repo_id=known_file.repo_id,
+                                                        filename=known_file.filename,
+                                                        repo_type=known_file.repo_type,
+                                                        revision=known_file.revision,
+                                                        watcher=watcher,
+                                                        local_dir=hf_destination_dir if args.force_hf_local_dir_mode else None,
+                                                        )
                 except IOError as exc_info:
                     logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
+                except Exception as exc_info:
+                    logger.error(f"an exception occurred while downloading {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
+                    dump_environment_info()
+                    for key, value in os.environ.items():
+                        if key.startswith("HF_XET"):
+                            print(f"{key}={value}", file=sys.stderr)

     if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
         tensors = {}
@@ -8,6 +8,7 @@ import os
 import re
 import traceback
 import zipfile
+import logging
 from pathlib import Path

 try:
@@ -26,6 +27,7 @@ from .component_model import files
 from .component_model.files import get_path_as_dict, get_package_as_path
 from .text_encoders.spiece_tokenizer import SPieceTokenizer

+logger = logging.getLogger(__name__)

 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -119,6 +121,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None and "model_name" not in model_options:
             model_options = {**model_options, "model_name": "clip_l"}

+        if "model_name" in model_options and "clip" not in model_options["model_name"].lower() and textmodel_json_config is None:
+            logger.warning(f"Text encoder {model_options['model_name']} provided a None textmodel_json_config, when it should have been an empty dict")
+            textmodel_json_config = {}

         config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)

         te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
@@ -30,7 +30,7 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):


 class LLAMAModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options=None, special_tokens=None):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options=None, special_tokens=None, textmodel_json_config=None):
         if special_tokens is None:
             special_tokens = {"start": 128000, "pad": 128258}
         if model_options is None:
@@ -40,7 +40,7 @@ class LLAMAModel(sd1_clip.SDClipModel):
             model_options = model_options.copy()
             model_options["scaled_fp8"] = llama_scaled_fp8

-        textmodel_json_config = {}
+        textmodel_json_config = textmodel_json_config or {}
         vocab_size = model_options.get("vocab_size", None)
         if vocab_size is not None:
             textmodel_json_config["vocab_size"] = vocab_size
@@ -1,8 +1,7 @@
-from dataclasses import dataclass
-from typing import Optional, Any
-
 import torch
 import torch.nn as nn
+from dataclasses import dataclass
+from typing import Optional, Any

 from ..ldm.common_dit import rms_norm
 from ..ldm.modules.attention import optimized_attention_for_device
@@ -25,6 +24,7 @@ class Llama2Config:
     mlp_activation = "silu"
     qkv_bias = False


 @dataclass
 class Qwen25_3BConfig:
     vocab_size: int = 151936
@@ -60,6 +60,7 @@ class Qwen25_7BVLI_Config:
     mlp_activation = "silu"
     qkv_bias = True


 @dataclass
 class Gemma2_2B_Config:
     vocab_size: int = 256000
@@ -363,6 +364,7 @@ class Llama2(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype


 class Qwen25_3B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -372,6 +374,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype


 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -381,6 +384,7 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype


 class Gemma2_2B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -25,7 +25,8 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
         if model_options is None:
             model_options = {}
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+        textmodel_json_config = textmodel_json_config or {}
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


 class LuminaModel(sd1_clip.SD1ClipModel):
@@ -29,9 +29,10 @@ class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
         return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)

 class Qwen25_3BModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
         if model_options is None:
             model_options = {}
+        textmodel_json_config = textmodel_json_config or {}
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

@@ -31,10 +31,11 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):


 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None):
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
         if model_options is None:
             model_options = {}
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+        textmodel_json_config = textmodel_json_config or {}
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


 class QwenImageTEModel(sd1_clip.SD1ClipModel):
@@ -82,4 +83,4 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
             dtype = dtype_llama
         super().__init__(device=device, dtype=dtype, model_options=model_options)

-    return QwenImageTEModel_
+    return QwenImageTEModel_
@@ -31,7 +31,7 @@ import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from pickle import UnpicklingError
-from typing import Optional, Any, Literal
+from typing import Optional, Any, Literal, Generator

 import numpy as np
 import safetensors.torch
@@ -48,7 +48,8 @@ from .cli_args import args
 from .component_model import files
 from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
-from .execution_context import current_execution_context, ExecutionContext
+from .component_model.tqdm_watcher import TqdmWatcher
+from .execution_context import current_execution_context
 from .gguf import gguf_sd_loader

 MMAP_TORCH_FILES = args.mmap_torch_files
@@ -115,9 +116,9 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
             if len(e.args) > 0:
                 message = e.args[0]
                 if "HeaderTooLarge" in message:
-                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
-                if "MetadataIncompleteBuffer" in message:
-                    raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
+                    raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.")
+                if "MetadataIncompleteBuffer" in message or "InvalidHeaderDeserialization" in message:
+                    raise ValueError(f"{message} (File path: {ckpt} The safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.")
             raise e
         elif ckpt.lower().endswith("index.json"):
             # from accelerate
@@ -1191,42 +1192,49 @@ def get_project_root() -> str:


 @contextmanager
-def comfy_tqdm():
+def comfy_tqdm() -> Generator[TqdmWatcher, None, None]:
     """
-    Monky patches child calls to tqdm and sends the progress to the UI
-    :return:
+    Monkey patches child calls to tqdm, sends progress to the UI,
+    and yields a watcher object for stall detection.
     """
     _original_init = tqdm.__init__
     _original_call = tqdm.__call__
     _original_update = tqdm.update

+    # Create the watcher instance that the patched methods will update
+    # and that will be yielded to the caller.
+    watcher = TqdmWatcher()
     context = contextvars.copy_context()

     try:
+        # These inner functions are closures; they capture the `watcher` variable
+        # from the enclosing scope.
         def __init(self, *args, **kwargs):
            context.run(lambda: _original_init(self, *args, **kwargs))
            self._progress_bar = context.run(lambda: ProgressBar(self.total))
+            watcher.tick()  # Signal progress on initialization

         def __update(self, n=1):
             assert self._progress_bar is not None
             context.run(lambda: _original_update(self, n))
             context.run(lambda: self._progress_bar.update(n))
+            watcher.tick()  # Signal progress on update

         def __call(self, *args, **kwargs):
             # When TQDM is called to wrap an iterable, ensure the instance is created
             # with the captured context
             instance = context.run(lambda: _original_call(self, *args, **kwargs))
             return instance

         tqdm.__init__ = __init
         tqdm.__call__ = __call
         tqdm.update = __update
         # todo: modify the tqdm class here to correctly copy the context into the function that tqdm is passed
-        yield
+        yield watcher
     finally:
         # Restore original tqdm
         tqdm.__init__ = _original_init
         tqdm.__call__ = _original_call
         tqdm.update = _original_update
         # todo: restore the context copying away


 @contextmanager
@@ -11,12 +11,13 @@ from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UP
 from comfy.model_management import load_models_gpu
 from comfy.model_management_types import ModelManageable

+logger = logging.getLogger(__name__)
 try:
     from spandrel_extra_arches import EXTRA_REGISTRY  # pylint: disable=import-error
     from spandrel import MAIN_REGISTRY

     MAIN_REGISTRY.add(*EXTRA_REGISTRY)
-    logging.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
+    logger.debug("Successfully imported spandrel_extra_arches: support for non commercial upscale models.")
 except:
     pass

tests/inference/workflows/qwen-image-gguf4-0.json (new file, 141 lines)

{
  "3": {
    "inputs": {
      "seed": 918331849236269,
      "steps": 1,
      "cfg": 1,
      "sampler_name": "res_multistep",
      "scheduler": "normal",
      "denoise": 1,
      "model": ["66", 0],
      "positive": ["6", 0],
      "negative": ["7", 0],
      "latent_image": ["58", 0]
    },
    "class_type": "KSampler",
    "_meta": {"title": "KSampler"}
  },
  "6": {
    "inputs": {
      "text": "cute anime girl with massive fennec ears and a big fluffy fox tail with long wavy blonde hair between eyes and large blue eyes blonde colored eyelashes chubby wearing oversized clothes summer uniform long blue maxi skirt muddy clothes happy sitting on the side of the road in a run down dark gritty cyberpunk city with neon and a crumbling skyscraper in the rain at night while dipping her feet in a river of water she is holding a sign that says \"ComfyUI is the best\" written in cursive",
      "clip": ["38", 0]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Positive Prompt)"}
  },
  "7": {
    "inputs": {
      "text": " ",
      "clip": ["38", 0]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Negative Prompt)"}
  },
  "8": {
    "inputs": {
      "samples": ["3", 0],
      "vae": ["39", 0]
    },
    "class_type": "VAEDecode",
    "_meta": {"title": "VAE Decode"}
  },
  "37": {
    "inputs": {
      "unet_name": "qwen-image-Q4_K_M.gguf",
      "weight_dtype": "default"
    },
    "class_type": "UNETLoader",
    "_meta": {"title": "Load Diffusion Model"}
  },
  "38": {
    "inputs": {
      "clip_name": "qwen_2.5_vl_7b.safetensors",
      "type": "qwen_image",
      "device": "default"
    },
    "class_type": "CLIPLoader",
    "_meta": {"title": "Load CLIP"}
  },
  "39": {
    "inputs": {
      "vae_name": "qwen_image_vae.safetensors"
    },
    "class_type": "VAELoader",
    "_meta": {"title": "Load VAE"}
  },
  "58": {
    "inputs": {
      "width": 1328,
      "height": 1328,
      "batch_size": 1
    },
    "class_type": "EmptySD3LatentImage",
    "_meta": {"title": "EmptySD3LatentImage"}
  },
  "60": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": ["8", 0]
    },
    "class_type": "SaveImage",
    "_meta": {"title": "Save Image"}
  },
  "66": {
    "inputs": {
      "shift": 3.1000000000000005,
      "model": ["37", 0]
    },
    "class_type": "ModelSamplingAuraFlow",
    "_meta": {"title": "ModelSamplingAuraFlow"}
  }
}
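The test harness presumably picks this workflow up by path. As a standalone illustration, not part of the commit, a workflow in this API format can also be queued against a running ComfyUI server; the default listen address and the /prompt endpoint are the stock assumptions here.

import json
import urllib.request

# Load the new test asset and submit it as an API-format prompt.
with open("tests/inference/workflows/qwen-image-gguf4-0.json", "r", encoding="utf-8") as f:
    workflow = json.load(f)

req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",  # default ComfyUI listen address; adjust if changed
    data=json.dumps({"prompt": workflow}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode("utf-8"))  # contains the prompt_id on success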
tests/unit/test_download_bailout.py (new file, 79 lines)

import threading
import time

import pytest

from comfy.component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
from comfy.component_model.tqdm_watcher import TqdmWatcher

download_method_name = "comfy.component_model.hf_hub_download_with_disable_xet.hf_hub_download_with_disable_fast"


def mock_stalled_download(*args, **kwargs):
    """A mock for hf_hub_download that simulates a stall by sleeping indefinitely."""
    time.sleep(10)
    return "this_path_should_never_be_returned"


def test_download_stalls_and_fails(monkeypatch):
    """
    Verify that a stalled download triggers retries and eventually fails with a RuntimeError.
    """
    monkeypatch.setattr(download_method_name, mock_stalled_download)
    watcher = TqdmWatcher()
    repo_id = "test/repo-stall"
    filename = "stalled_file.safetensors"

    with pytest.raises(RuntimeError) as excinfo:
        hf_hub_download_with_retries(
            repo_id=repo_id,
            filename=filename,
            watcher=watcher,
            stall_timeout=0.2,
            retries=2,
        )

    assert f"Failed to download '{repo_id}/{filename}' after 2 attempts" in str(excinfo.value)


def mock_successful_slow_download(*args, **kwargs):
    """A mock for a download that is slow but not stalled."""
    time.sleep(1)

    return "expected/successful/path"


def _keep_watcher_alive(watcher: TqdmWatcher, stop_event: threading.Event):
    """Helper function to run in a thread and periodically tick the watcher."""
    while not stop_event.is_set():
        watcher.tick()
        time.sleep(0.1)


def test_download_progresses_and_succeeds(monkeypatch):
    """
    Verify that a download with periodic progress updates completes successfully.
    """
    monkeypatch.setattr(download_method_name, mock_successful_slow_download)

    watcher = TqdmWatcher()
    stop_event = threading.Event()
    ticker_thread = threading.Thread(
        target=_keep_watcher_alive,
        args=(watcher, stop_event),
        daemon=True,
    )
    ticker_thread.start()

    try:
        result = hf_hub_download_with_retries(
            repo_id="test/repo-success",
            filename="good_file.safetensors",
            stall_timeout=0.3,
            watcher=watcher,
        )
        assert result == "expected/successful/path"
    finally:
        stop_event.set()
        ticker_thread.join(timeout=1)
@@ -74,7 +74,7 @@ def mock_openai_client():
     instance.images.generate = Mock()
     yield instance


+@pytest.mark.skip("broken transformers")
 def test_transformers_loader(has_gpu):
     if not has_gpu:
         pytest.skip("requires GPU")