Improvements to GGUF

- GGUF now works and includes 4-bit quants, which allows WAN to run on GPUs with 24 GB of VRAM
- the logger only shows full stack traces for errors
- helper functions for Colab notebooks (see the usage sketch below)
- fixed the nvcr.io auth error in the CUDA image build workflow
- LoRA loading issues are now reported more clearly, including the LoRA name and the keys that were not loaded
- the model downloader uses the Hugging Face cache and symlinking when your platform supports it
- the torch compile node now correctly patches the model before compilation
doctorpangloss 2025-07-30 18:28:52 -07:00
parent 69a4906964
commit 83184916c1
14 changed files with 256 additions and 96 deletions
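For the Colab helpers added in comfy/app/colab.py (see the diff below), a minimal usage sketch from a notebook cell might look like the following. This is a sketch, not part of the commit; it assumes the package is importable as comfy and that the notebook kernel exposes a running asyncio event loop, which start_server_in_colab relies on via asyncio.get_running_loop():

    # Usage sketch (assumes an importable `comfy` package and a running asyncio loop in the kernel)
    from comfy.app.colab import start_server_in_colab

    url = start_server_in_colab()  # starts ComfyUI in the background and opens a cloudflared tunnel on port 8188
    print(f"ComfyUI is reachable at {url}")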

View File

@@ -25,6 +25,11 @@ jobs:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
+      - uses: docker/login-action@v3
+        with:
+          registry: nvcr.io
+          username: "$oauthtoken"
+          password: ${{ secrets.NVCR_NGC_TOKEN }}
       - name: Build and push CUDA (NVIDIA) image
         uses: docker/build-push-action@v6
         with:

comfy/app/colab.py (new file, 138 lines)
View File

@@ -0,0 +1,138 @@
import asyncio
import os
import stat
import subprocess
import threading
from asyncio import Task
from typing import NamedTuple

import requests

from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
from ..execution_context import *


class _ColabTuple(NamedTuple):
    tunnel: "CloudflaredTunnel"
    server: Task


_colab_instances: list[_ColabTuple] = []


class CloudflaredTunnel:
    """
    A class to manage a cloudflared tunnel subprocess.

    Provides methods to start, stop, and manage the lifecycle of the tunnel.
    It can be used as a context manager.
    """

    def __init__(self, port: int):
        self._port: int = port
        self._executable_path: str = "./cloudflared"
        self._process: Optional[subprocess.Popen] = None
        self._thread: Optional[threading.Thread] = None
        # Download and set permissions for the executable
        self._setup_executable()
        # Start the tunnel process and capture the URL
        self.url: str = self._start_tunnel()

    def _setup_executable(self):
        """Downloads cloudflared and makes it executable if it doesn't exist."""
        if not os.path.exists(self._executable_path):
            url = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(self._executable_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            # Make the file executable (add execute permission for the owner)
            current_permissions = os.stat(self._executable_path).st_mode
            os.chmod(self._executable_path, current_permissions | stat.S_IEXEC)

    def _start_tunnel(self) -> str:
        """Starts the tunnel and returns the public URL."""
        command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"]
        # Using DEVNULL for stderr to keep the output clean, stdout is piped
        self._process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True
        )
        for line in iter(self._process.stdout.readline, ""):
            if ".trycloudflare.com" in line:
                # The line format is typically: "INFO | https://<subdomain>.trycloudflare.com |"
                try:
                    url = line.split("|")[1].strip()
                    print(f"Tunnel is live at: {url}")
                    return url
                except IndexError:
                    continue
        # If the loop finishes without finding a URL
        self.stop()
        raise RuntimeError("Failed to start cloudflared tunnel or find URL.")

    def stop(self):
        """Stops the cloudflared tunnel process."""
        if self._process and self._process.poll() is None:
            print("Stopping cloudflared tunnel...")
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
                print("Tunnel stopped successfully.")
            except subprocess.TimeoutExpired:
                print("Tunnel did not terminate gracefully, forcing kill.")
                self._process.kill()
            self._process = None

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring the tunnel is stopped."""
        self.stop()


def start_tunnel(port: int) -> CloudflaredTunnel:
    """
    Initializes and starts a cloudflared tunnel.

    Args:
        port: The local port number to expose.

    Returns:
        A CloudflaredTunnel object that controls the tunnel process.
        This object has a `url` attribute and a `stop()` method.
    """
    return CloudflaredTunnel(port)


def start_server_in_colab() -> str:
    """
    returns the URL of the tunnel and the running context
    :return:
    """
    if len(_colab_instances) == 0:
        comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))

        async def colab_server_loop():
            init_default_paths(folder_names_and_paths)
            await _start_comfyui()

        _loop = asyncio.get_running_loop()
        task = _loop.create_task(colab_server_loop())
        tunnel = start_tunnel(8188)
        _colab_instances.append(_ColabTuple(tunnel, task))
    return _colab_instances[0].tunnel.url
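The tunnel class above can also be used on its own. A short sketch as a context manager (not part of the commit; the port is illustrative and assumes something is already listening on it):

    # Sketch: expose a local service listening on port 8188 (illustrative)
    from comfy.app.colab import start_tunnel

    with start_tunnel(8188) as tunnel:
        print(tunnel.url)  # public *.trycloudflare.com URL, also printed by _start_tunnel
        ...                # do work while the tunnel is up
    # leaving the block calls stop(), which terminates the cloudflared process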

View File

@@ -55,7 +55,7 @@ def on_flush(callback):
 class StackTraceLogger(logging.Logger):
     def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
-        if level >= logging.WARNING:
+        if level >= logging.ERROR:
             stack_info = True
         super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)

View File

@@ -171,6 +171,16 @@ class Comfy:
     def task_count(self) -> int:
         return self._task_count
 
+    def __enter__(self):
+        self._is_running = True
+        return self
+
+    def __exit__(self, *args):
+        get_event_loop().run_in_executor(self._executor, _cleanup)
+        self._executor.shutdown(wait=True)
+        self._is_running = False
+
     async def __aenter__(self):
         self._is_running = True
         return self

View File

@@ -19,15 +19,14 @@ from typing import List, Optional, Tuple, Literal
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status
 
+# order matters
+from .main_pre import tracer
 from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
     DependencyAwareCache, \
     BasicCache
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
 from comfy_execution.utils import CurrentNodeContext
-# order matters
-from .main_pre import tracer
 from .. import interruption
 from .. import model_management
 from ..cli_args import args
@@ -38,7 +37,8 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
     HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.module_property import create_module_properties
-from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
+from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
+    ExecutionStatusAsDict
 from ..execution_context import context_execute_node, context_execute_prompt
 from ..execution_ext import should_panic_on_exception
 from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@@ -48,6 +48,7 @@ from ..progress import get_progress_state, reset_progress_state, add_progress_ha
 from ..validation import validate_node_input
 
 _module_properties = create_module_properties()
+logger = logging.getLogger(__name__)
 
 
 @_module_properties.getter
@@ -100,12 +101,12 @@ class CacheSet:
     def __init__(self, cache_type=None, cache_size=None):
         if cache_type == CacheType.DEPENDENCY_AWARE:
             self.init_dependency_aware_cache()
-            logging.info("Disabling intermediate node cache.")
+            logger.info("Disabling intermediate node cache.")
         elif cache_type == CacheType.LRU:
             if cache_size is None:
                 cache_size = 0
             self.init_lru_cache(cache_size)
-            logging.info("Using LRU cache")
+            logger.info("Using LRU cache")
         else:
             self.init_classic_cache()
@@ -571,7 +572,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
                 return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
             caches.outputs.set(unique_id, output_data)
     except interruption.InterruptProcessingException as iex:
-        logging.info("Processing interrupted")
+        logger.info("Processing interrupted")
 
         # skip formatting inputs/outputs
         error_details: RecursiveExecutionErrorDetailsInterrupted = {
@@ -588,13 +589,13 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
            for name, inputs in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in inputs]
 
-        logging.error("An error occurred while executing a workflow", exc_info=ex)
-        logging.error(traceback.format_exc())
+        logger.error("An error occurred while executing a workflow", exc_info=ex)
+        logger.error(traceback.format_exc())
 
        tips = ""
 
        if isinstance(ex, model_management.OOM_EXCEPTION):
            tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
-            logging.error("Got an OOM, unloading all loaded models.")
+            logger.error("Got an OOM, unloading all loaded models.")
            model_management.unload_all_models()
 
        error_details: RecursiveExecutionErrorDetails = {
@@ -606,7 +607,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
        }
 
        if should_panic_on_exception(ex, args.panic_when):
-            logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
+            logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
 
            def sys_exit(*args):
                sys.exit(1)
@@ -1120,11 +1121,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
        if valid is True:
            good_outputs.add(o)
        else:
-            logging.error(f"Failed to validate prompt for output {o}:")
+            logger.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
-                logging.error("* (prompt):")
+                logger.error("* (prompt):")
                for reason in reasons:
-                    logging.error(f" - {reason['message']}: {reason['details']}")
+                    logger.error(f" - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                valid = result[0]
@@ -1140,11 +1141,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
                    "dependent_outputs": [],
                    "class_type": class_type
                }
-                logging.error(f"* {class_type} {node_id}:")
+                logger.error(f"* {class_type} {node_id}:")
                for reason in reasons:
-                    logging.error(f" - {reason['message']}: {reason['details']}")
+                    logger.error(f" - {reason['message']}: {reason['details']}")
                node_errors[node_id]["dependent_outputs"].append(o)
-        logging.error("Output will be ignored")
+        logger.error("Output will be ignored")
 
    if len(good_outputs) == 0:
        errors_list = []

View File

@@ -93,7 +93,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
         ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
         ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
-        ModelPaths(folder_names=["clip", "text_encoders"], supported_extensions=set(supported_pt_extensions)),
+        ModelPaths(folder_names=["text_encoders", "clip"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
         ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),

View File

@@ -939,16 +939,16 @@ def get_torch_compiler_disable_decorator():
     from packaging import version
 
     if not chained_hasattr(torch, "compiler.disable"):
-        logger.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
+        logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
         return dummy_decorator  # torch too old
     elif version.parse(torch.__version__) >= version.parse("2.8"):
-        logger.info("ComfyUI-GGUF: Allowing full torch compile")
+        logger.debug("ComfyUI-GGUF: Allowing full torch compile")
         return dummy_decorator  # torch compile works
 
     if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
-        logger.info("ComfyUI-GGUF: Allowing full torch compile (nightly)")
+        logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
         return dummy_decorator  # torch compile works, nightly before 2.8 release
     else:
-        logger.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
+        logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
         return torch.compiler.disable

View File

@@ -24,8 +24,10 @@ import torch
 from . import model_base
 from . import model_management
 from . import utils
-from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
 from . import weight_adapter
+from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
+
+logger = logging.getLogger(__name__)
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -37,7 +39,7 @@ LORA_CLIP_MAP = {
 }
 
 
-def load_lora(lora, to_load, log_missing=True) -> PatchDict:
+def load_lora(lora, to_load, log_missing=True, lora_name=None) -> PatchDict:
     patch_dict: PatchDict = {}
     loaded_keys = set()
     for x in to_load:
@@ -91,9 +93,10 @@ def load_lora(lora, to_load, log_missing=True) -> PatchDict:
             loaded_keys.add(set_weight_name)
 
     if log_missing:
-        for x in lora.keys():
-            if x not in loaded_keys:
-                logging.warning("lora key not loaded: {}".format(x))
+        not_loaded_keys = [x for x in lora.keys() if x not in loaded_keys]
+        n_not_loaded_keys = len(not_loaded_keys)
+        if n_not_loaded_keys > 0:
+            logger.warning(f"[{lora_name}] lora keys not loaded ({n_not_loaded_keys} / {len(loaded_keys) + n_not_loaded_keys}): {not_loaded_keys}")
 
     return patch_dict
@@ -293,12 +296,12 @@ def model_lora_keys_unet(model, key_map=None):
             if k.startswith("diffusion_model."):
                 if k.endswith(".weight"):
                     key_lora = k[len("diffusion_model."):-len(".weight")]
-                    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
-                    key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
+                    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k  # SimpleTuner lycoris format
+                    key_map["transformer.{}".format(key_lora)] = k  # SimpleTuner regular format
 
     if isinstance(model, model_base.ACEStep):
         for k in sdk:
-            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
+            if k.startswith("diffusion_model.") and k.endswith(".weight"):  # Official ACE step lora format
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["{}".format(key_lora)] = k
@@ -364,7 +367,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
             if isinstance(v, weight_adapter.WeightAdapterBase):
                 output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
                 if output is None:
-                    logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
+                    logger.warning("Calculate Weight Failed: {} {}".format(v.name, key))
                 else:
                     weight = output
                     if old_weight is not None:
@@ -382,12 +385,12 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
                 # An extra flag to pad the weight if the diff's shape is larger than the weight
                 do_pad_weight = len(v) > 1 and v[1]['pad_weight']
                 if do_pad_weight and diff.shape != weight.shape:
-                    logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
+                    logger.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                     weight = pad_tensor_to_shape(weight, diff.shape)
 
                 if strength != 0.0:
                     if diff.shape != weight.shape:
-                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
+                        logger.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                     else:
                         weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype))
             elif patch_type == "set":
@@ -398,7 +401,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
                     model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
                 weight += function(strength * model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
             else:
-                logging.warning("patch type not recognized {} {}".format(patch_type, key))
+                logger.warning("patch type not recognized {} {}".format(patch_type, key))
 
         if old_weight is not None:
             weight = old_weight

View File

@@ -1,8 +1,5 @@
 from __future__ import annotations
 
-from itertools import chain
-from os.path import join
-
 import collections
 import logging
 import operator
@@ -10,6 +7,8 @@ import os
 import shutil
 from collections.abc import Sequence, MutableSequence
 from functools import reduce
+from itertools import chain
+from os.path import join
 from pathlib import Path
 from typing import List, Optional, Final, Set
@@ -66,28 +65,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
     filename = canonicalize_path(filename)
     path = folder_paths.get_full_path(folder_name, filename)
 
+    candidate_str_match = False
+    candidate_filename_match = False
+    candidate_alternate_filenames_match = False
+    candidate_save_filename_match = False
     if path is None and not args.disable_known_models:
         try:
             # todo: should this be the first or last path?
             this_model_directory = folder_paths.get_folder_paths(folder_name)[0]
             known_file: Optional[HuggingFile | CivitFile] = None
             for candidate in known_files:
-                if (canonicalize_path(str(candidate)) == filename
-                        or canonicalize_path(candidate.filename) == filename
-                        or filename in list(map(canonicalize_path, candidate.alternate_filenames))
-                        or filename == canonicalize_path(candidate.save_with_filename)):
+                candidate_str_match = canonicalize_path(str(candidate)) == filename
+                candidate_filename_match = canonicalize_path(candidate.filename) == filename
+                candidate_alternate_filenames_match = filename in list(map(canonicalize_path, candidate.alternate_filenames))
+                candidate_save_filename_match = filename == canonicalize_path(candidate.save_with_filename)
+                if (candidate_str_match
+                        or candidate_filename_match
+                        or candidate_alternate_filenames_match
+                        or candidate_save_filename_match):
                     known_file = candidate
                     break
             if known_file is None:
+                logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}")
                 return path
 
             with comfy_tqdm():
                 if isinstance(known_file, HuggingFile):
+                    symlinks_supported = are_symlinks_supported()
                     if known_file.save_with_filename is not None:
                         linked_filename = known_file.save_with_filename
                     elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
                         linked_filename = os.path.basename(known_file.filename)
                     else:
-                        linked_filename = None
+                        linked_filename = known_file.filename
 
                     if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
                         # if the known file has an overridden linked name, save it into a repo_id sub directory
@@ -112,8 +121,6 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                     cache_hit = False
                     try:
-                        if not are_symlinks_supported():
-                            raise PermissionError("no symlink support")
                         # always retrieve this from the cache if it already exists there
                         path = hf_hub_download(repo_id=known_file.repo_id,
                                                filename=known_file.filename,
@@ -121,19 +128,20 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                                               revision=known_file.revision,
                                               local_files_only=True,
                                               )
-                        logger.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
-                        if linked_filename is None:
-                            linked_filename = known_file.filename
+                        logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
                         cache_hit = True
-                    except (LocalEntryNotFoundError, PermissionError):
-                        path = hf_hub_download(repo_id=known_file.repo_id,
-                                               filename=known_file.filename,
-                                               local_dir=hf_destination_dir,
-                                               repo_type=known_file.repo_type,
-                                               revision=known_file.revision,
-                                               )
+                    except LocalEntryNotFoundError:
+                        try:
+                            logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
+                            path = hf_hub_download(repo_id=known_file.repo_id,
+                                                   filename=known_file.filename,
+                                                   repo_type=known_file.repo_type,
+                                                   revision=known_file.revision,
+                                                   )
+                        except IOError as exc_info:
+                            logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
 
-                    if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
+                    if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
                         tensors = {}
                         with safe_open(path, framework="pt") as f:
                             with tqdm.tqdm(total=len(f.keys())) as pb:
@@ -151,20 +159,23 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                        logger.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
 
                    link_successful = True
-                    if linked_filename is not None:
+                    if path is not None:
                        destination_link = os.path.join(this_model_directory, linked_filename)
-                        try:
-                            os.makedirs(this_model_directory, exist_ok=True)
-                            os.symlink(path, destination_link)
-                        except Exception as exc_info:
-                            logger.error("error while symbolic linking", exc_info=exc_info)
-                            try:
-                                os.link(path, destination_link)
-                            except Exception as hard_link_exc:
-                                logger.error("error while hard linking", exc_info=hard_link_exc)
-                                if cache_hit:
-                                    shutil.copyfile(path, destination_link)
-                                link_successful = False
+                        if Path(destination_link).is_file():
+                            logger.warning(f"{known_file.repo_id}/{known_file.filename} could not link to {destination_link} because the path already exists, which is unexpected")
+                        else:
+                            try:
+                                os.makedirs(this_model_directory, exist_ok=True)
+                                os.symlink(path, destination_link)
+                            except Exception as exc_info:
+                                logger.error("error while symbolic linking", exc_info=exc_info)
+                                try:
+                                    os.link(path, destination_link)
+                                except Exception as hard_link_exc:
+                                    logger.error("error while hard linking", exc_info=hard_link_exc)
+                                    if cache_hit:
+                                        shutil.copyfile(path, destination_link)
+                                    link_successful = False
 
                    if not link_successful:
                        logger.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
@@ -558,7 +569,9 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
     HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf"),
     HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q4_K_M.gguf"),
 ], folder_names=["diffusion_models", "unet"])
 
 KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([

View File

@@ -630,7 +630,7 @@ class ModelPatcher(ModelManageable):
             # from gguf
             if is_quantized(weight):
                 out_weight = weight.to(device_to)
-                patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
+                patches = move_patch_to_device(self.patches[key], self.load_device if self.gguf.patch_on_device else self.offload_device)
                 # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
                 out_weight.patches = [(patches, key)]
                 if inplace_update:

View File

@@ -674,7 +674,7 @@ class LoraLoader:
             lora = utils.load_torch_file(lora_path, safe_load=True)
             self.loaded_lora = (lora_path, lora)
 
-        model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
+        model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=lora_name)
         return (model_lora, clip_lora)
 
 
 class LoraLoaderModelOnly(LoraLoader):

View File

@@ -15,7 +15,6 @@ import yaml
 from . import clip_vision
 from . import diffusers_convert
 from . import gligen
-from . import lora
 from . import model_detection
 from . import model_management
 from . import model_patcher
@@ -37,6 +36,7 @@ from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.wan import vae as wan_vae
 from .ldm.wan import vae2_2 as wan_vae2_2
+from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip
 from .lora_convert import convert_lora
 from .model_management import load_models_gpu
 from .model_patcher import ModelPatcher
@@ -64,15 +64,15 @@ from .utils import ProgressBar, FileMetadata
 logger = logging.getLogger(__name__)
 
 
-def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=None):
     key_map = {}
     if model is not None:
-        key_map = lora.model_lora_keys_unet(model.model, key_map)
+        key_map = model_lora_keys_unet(model.model, key_map)
     if clip is not None:
-        key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
 
-    _lora = convert_lora(_lora)
-    loaded = lora.load_lora(_lora, key_map)
+    lora = convert_lora(lora)
+    loaded = load_lora(lora, key_map, lora_name=lora_name)
     if model is not None:
         new_modelpatcher: ModelPatcher = model.clone()
         k = new_modelpatcher.add_patches(loaded, strength_model)
@@ -90,7 +90,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
         k1 = set(k1)
         for x in loaded:
             if (x not in k) and (x not in k1):
-                logger.warning("NOT LOADED {}".format(x))
+                logger.warning(f"[{lora_name}] clip keys not loaded {x}".format(x))
 
     return (new_modelpatcher, new_clip)

View File

@@ -80,7 +80,6 @@ class TorchCompileModel(CustomNode):
            "backend": backend,
            "mode": mode,
        }
-        move_to_gpu = True
        try:
            if backend == "torch_tensorrt":
                try:
@@ -98,7 +97,6 @@ class TorchCompileModel(CustomNode):
                        "enable_weight_streaming": True,
                        "make_refittable": True,
                    }
-                    move_to_gpu = True
                    del compile_kwargs["mode"]
            if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
                to_return = model.clone()
@@ -109,27 +107,18 @@ class TorchCompileModel(CustomNode):
                    object_patches = ["encoder", "decoder"]
                else:
                    patcher = to_return
-                if object_patch is None or len(object_patches) == 0:
+                if object_patch is None or len(object_patches) == 0 or len(object_patches) == 1 and object_patches[0].strip() == "":
                    object_patches = [DIFFUSION_MODEL]
-                if move_to_gpu:
-                    model_management.unload_all_models()
-                    model_management.load_models_gpu([patcher])
                set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
-                # m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
-                # todo: do we want to move something back off the GPU?
-                # if move_to_gpu:
-                #     model_management.unload_all_models()
                return to_return,
            elif isinstance(model, torch.nn.Module):
-                if move_to_gpu:
-                    model_management.unload_all_models()
-                    model.to(device=model_management.get_torch_device())
+                model_management.unload_all_models()
+                model.to(device=model_management.get_torch_device())
                res = torch.compile(model=model, **compile_kwargs),
-                if move_to_gpu:
-                    model.to(device=model_management.unet_offload_device())
+                model.to(device=model_management.unet_offload_device())
                return res,
            else:
-                logger.warning("Encountered a model that cannot be compiled")
+                logger.warning(f"Encountered a model {model} that cannot be compiled")
                return model,
        except OSError as os_error:
            try:
@@ -174,7 +163,8 @@ class QuantizeModel(CustomNode):
     def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
-        unet = model.get_model_object("diffusion_model")
+        model.patch_model(force_patch_weights=True)
+        unet = model.diffusion_model
 
         # todo: quantize quantizes in place, which is not desired
         # default exclusions
@@ -209,7 +199,7 @@ class QuantizeModel(CustomNode):
            if "autoquant" in strategy:
                _in_place_fixme = autoquant(unet, error_on_unseen=False)
            else:
-                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
+                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
                _in_place_fixme = unet
 
            unwrap_tensor_subclass(_in_place_fixme)
        else:

View File

@@ -717,7 +717,7 @@ class LoraModelLoader:
         if strength_model == 0:
             return (model, )
 
-        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
+        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0, None)
         return (model_lora, )