From 83184916c1ef7038c629970504cf8481cabba2a9 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Wed, 30 Jul 2025 18:28:52 -0700 Subject: [PATCH] Improvements to GGUF - GGUF now works and included 4 bit quants, this will allow WAN to run on 24GB VRAM GPUs - logger only shows full stack for errors - helper functions for colab notebook - fix nvcr.io auth error - lora issues are now reported better - model downloader will use huggingface cache and symlinking if it is supported on your platform - torch compile node now correctly patches the model before compilation --- .github/workflows/docker-build.yml | 5 + comfy/app/colab.py | 138 ++++++++++++++++++++++ comfy/app/logger.py | 2 +- comfy/client/embedded_comfy_client.py | 10 ++ comfy/cmd/execution.py | 35 +++--- comfy/cmd/folder_paths.py | 2 +- comfy/gguf.py | 8 +- comfy/lora.py | 27 +++-- comfy/model_downloader.py | 79 +++++++------ comfy/model_patcher.py | 2 +- comfy/nodes/base_nodes.py | 2 +- comfy/sd.py | 14 +-- comfy_extras/nodes/nodes_torch_compile.py | 26 ++-- comfy_extras/nodes/nodes_train.py | 2 +- 14 files changed, 256 insertions(+), 96 deletions(-) create mode 100644 comfy/app/colab.py diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 23f525a2c..29aee8d83 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -25,6 +25,11 @@ jobs: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - uses: docker/login-action@v3 + with: + registry: nvcr.io + username: "$oauthtoken" + password: ${{ secrets.NVCR_NGC_TOKEN }} - name: Build and push CUDA (NVIDIA) image uses: docker/build-push-action@v6 with: diff --git a/comfy/app/colab.py b/comfy/app/colab.py new file mode 100644 index 000000000..da6cfe717 --- /dev/null +++ b/comfy/app/colab.py @@ -0,0 +1,138 @@ +import asyncio +import os +import stat +import subprocess +import threading +from asyncio import Task +from typing import NamedTuple + +import requests + +from ..cmd.folder_paths import init_default_paths, folder_names_and_paths +# experimental workarounds for colab +from ..cmd.main import _start_comfyui +from ..execution_context import * + + +class _ColabTuple(NamedTuple): + tunnel: "CloudflaredTunnel" + server: Task + + +_colab_instances: list[_ColabTuple] = [] + + +class CloudflaredTunnel: + """ + A class to manage a cloudflared tunnel subprocess. + + Provides methods to start, stop, and manage the lifecycle of the tunnel. + It can be used as a context manager. 
+ """ + + def __init__(self, port: int): + self._port: int = port + self._executable_path: str = "./cloudflared" + self._process: Optional[subprocess.Popen] = None + self._thread: Optional[threading.Thread] = None + + # Download and set permissions for the executable + self._setup_executable() + + # Start the tunnel process and capture the URL + self.url: str = self._start_tunnel() + + def _setup_executable(self): + """Downloads cloudflared and makes it executable if it doesn't exist.""" + if not os.path.exists(self._executable_path): + url = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64" + response = requests.get(url, stream=True) + response.raise_for_status() + with open(self._executable_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + # Make the file executable (add execute permission for the owner) + current_permissions = os.stat(self._executable_path).st_mode + os.chmod(self._executable_path, current_permissions | stat.S_IEXEC) + + def _start_tunnel(self) -> str: + """Starts the tunnel and returns the public URL.""" + command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"] + + # Using DEVNULL for stderr to keep the output clean, stdout is piped + self._process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.DEVNULL, + text=True + ) + + for line in iter(self._process.stdout.readline, ""): + if ".trycloudflare.com" in line: + # The line format is typically: "INFO | https://.trycloudflare.com |" + try: + url = line.split("|")[1].strip() + print(f"Tunnel is live at: {url}") + return url + except IndexError: + continue + + # If the loop finishes without finding a URL + self.stop() + raise RuntimeError("Failed to start cloudflared tunnel or find URL.") + + def stop(self): + """Stops the cloudflared tunnel process.""" + if self._process and self._process.poll() is None: + print("Stopping cloudflared tunnel...") + self._process.terminate() + try: + self._process.wait(timeout=5) + print("Tunnel stopped successfully.") + except subprocess.TimeoutExpired: + print("Tunnel did not terminate gracefully, forcing kill.") + self._process.kill() + self._process = None + + def __enter__(self): + """Enter context manager.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager, ensuring the tunnel is stopped.""" + self.stop() + + +def start_tunnel(port: int) -> CloudflaredTunnel: + """ + Initializes and starts a cloudflared tunnel. + + Args: + port: The local port number to expose. + + Returns: + A CloudflaredTunnel object that controls the tunnel process. + This object has a `url` attribute and a `stop()` method. 
+ """ + return CloudflaredTunnel(port) + + +def start_server_in_colab() -> str: + """ + returns the URL of the tunnel and the running context + :return: + """ + if len(_colab_instances) == 0: + comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub())) + + async def colab_server_loop(): + init_default_paths(folder_names_and_paths) + await _start_comfyui() + + _loop = asyncio.get_running_loop() + task = _loop.create_task(colab_server_loop()) + + tunnel = start_tunnel(8188) + _colab_instances.append(_ColabTuple(tunnel, task)) + return _colab_instances[0].tunnel.url diff --git a/comfy/app/logger.py b/comfy/app/logger.py index 82482240a..095b174b9 100644 --- a/comfy/app/logger.py +++ b/comfy/app/logger.py @@ -55,7 +55,7 @@ def on_flush(callback): class StackTraceLogger(logging.Logger): def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1): - if level >= logging.WARNING: + if level >= logging.ERROR: stack_info = True super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1) diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 98edaa0ad..215e79c6f 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -171,6 +171,16 @@ class Comfy: def task_count(self) -> int: return self._task_count + def __enter__(self): + self._is_running = True + return self + + def __exit__(self, *args): + get_event_loop().run_in_executor(self._executor, _cleanup) + self._executor.shutdown(wait=True) + self._is_running = False + + async def __aenter__(self): self._is_running = True return self diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index b4641d8ec..0f2caa428 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -19,15 +19,14 @@ from typing import List, Optional, Tuple, Literal import torch from opentelemetry.trace import get_current_span, StatusCode, Status -# order matters -from .main_pre import tracer - from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ DependencyAwareCache, \ BasicCache from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.utils import CurrentNodeContext +# order matters +from .main_pre import tracer from .. import interruption from .. 
import model_management from ..cli_args import args @@ -38,7 +37,8 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage from ..component_model.files import canonicalize_path from ..component_model.module_property import create_module_properties -from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict +from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \ + ExecutionStatusAsDict from ..execution_context import context_execute_node, context_execute_prompt from ..execution_ext import should_panic_on_exception from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode @@ -48,6 +48,7 @@ from ..progress import get_progress_state, reset_progress_state, add_progress_ha from ..validation import validate_node_input _module_properties = create_module_properties() +logger = logging.getLogger(__name__) @_module_properties.getter @@ -100,12 +101,12 @@ class CacheSet: def __init__(self, cache_type=None, cache_size=None): if cache_type == CacheType.DEPENDENCY_AWARE: self.init_dependency_aware_cache() - logging.info("Disabling intermediate node cache.") + logger.info("Disabling intermediate node cache.") elif cache_type == CacheType.LRU: if cache_size is None: cache_size = 0 self.init_lru_cache(cache_size) - logging.info("Using LRU cache") + logger.info("Using LRU cache") else: self.init_classic_cache() @@ -571,7 +572,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None) caches.outputs.set(unique_id, output_data) except interruption.InterruptProcessingException as iex: - logging.info("Processing interrupted") + logger.info("Processing interrupted") # skip formatting inputs/outputs error_details: RecursiveExecutionErrorDetailsInterrupted = { @@ -588,13 +589,13 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - logging.error("An error occurred while executing a workflow", exc_info=ex) - logging.error(traceback.format_exc()) + logger.error("An error occurred while executing a workflow", exc_info=ex) + logger.error(traceback.format_exc()) tips = "" if isinstance(ex, model_management.OOM_EXCEPTION): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." 
- logging.error("Got an OOM, unloading all loaded models.") + logger.error("Got an OOM, unloading all loaded models.") model_management.unload_all_models() error_details: RecursiveExecutionErrorDetails = { @@ -606,7 +607,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra } if should_panic_on_exception(ex, args.panic_when): - logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit") + logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit") def sys_exit(*args): sys.exit(1) @@ -1120,11 +1121,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty if valid is True: good_outputs.add(o) else: - logging.error(f"Failed to validate prompt for output {o}:") + logger.error(f"Failed to validate prompt for output {o}:") if len(reasons) > 0: - logging.error("* (prompt):") + logger.error("* (prompt):") for reason in reasons: - logging.error(f" - {reason['message']}: {reason['details']}") + logger.error(f" - {reason['message']}: {reason['details']}") errors += [(o, reasons)] for node_id, result in validated.items(): valid = result[0] @@ -1140,11 +1141,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty "dependent_outputs": [], "class_type": class_type } - logging.error(f"* {class_type} {node_id}:") + logger.error(f"* {class_type} {node_id}:") for reason in reasons: - logging.error(f" - {reason['message']}: {reason['details']}") + logger.error(f" - {reason['message']}: {reason['details']}") node_errors[node_id]["dependent_outputs"].append(o) - logging.error("Output will be ignored") + logger.error("Output will be ignored") if len(good_outputs) == 0: errors_list = [] diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 59832b638..02c496ec2 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -93,7 +93,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)), ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}), ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)), - ModelPaths(folder_names=["clip", "text_encoders"], supported_extensions=set(supported_pt_extensions)), + ModelPaths(folder_names=["text_encoders", "clip"], supported_extensions=set(supported_pt_extensions)), ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)), ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True), ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)), diff --git a/comfy/gguf.py b/comfy/gguf.py index d69567532..6a7a12b73 100644 --- a/comfy/gguf.py +++ b/comfy/gguf.py @@ -939,16 +939,16 @@ def get_torch_compiler_disable_decorator(): from packaging import version if not chained_hasattr(torch, "compiler.disable"): - logger.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing") + logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing") return dummy_decorator # torch too old elif version.parse(torch.__version__) >= version.parse("2.8"): - logger.info("ComfyUI-GGUF: Allowing full torch compile") + logger.debug("ComfyUI-GGUF: Allowing full torch compile") return dummy_decorator # torch compile works if chained_hasattr(torch, 
"_dynamo.config.nontraceable_tensor_subclasses"): - logger.info("ComfyUI-GGUF: Allowing full torch compile (nightly)") + logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)") return dummy_decorator # torch compile works, nightly before 2.8 release else: - logger.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch") + logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch") return torch.compiler.disable diff --git a/comfy/lora.py b/comfy/lora.py index 659bb6cff..462eaf1fa 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -24,8 +24,10 @@ import torch from . import model_base from . import model_management from . import utils -from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue from . import weight_adapter +from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue + +logger = logging.getLogger(__name__) LORA_CLIP_MAP = { "mlp.fc1": "mlp_fc1", @@ -37,7 +39,7 @@ LORA_CLIP_MAP = { } -def load_lora(lora, to_load, log_missing=True) -> PatchDict: +def load_lora(lora, to_load, log_missing=True, lora_name=None) -> PatchDict: patch_dict: PatchDict = {} loaded_keys = set() for x in to_load: @@ -91,9 +93,10 @@ def load_lora(lora, to_load, log_missing=True) -> PatchDict: loaded_keys.add(set_weight_name) if log_missing: - for x in lora.keys(): - if x not in loaded_keys: - logging.warning("lora key not loaded: {}".format(x)) + not_loaded_keys = [x for x in lora.keys() if x not in loaded_keys] + n_not_loaded_keys = len(not_loaded_keys) + if n_not_loaded_keys > 0: + logger.warning(f"[{lora_name}] lora keys not loaded ({n_not_loaded_keys} / {len(loaded_keys) + n_not_loaded_keys}): {not_loaded_keys}") return patch_dict @@ -293,12 +296,12 @@ def model_lora_keys_unet(model, key_map=None): if k.startswith("diffusion_model."): if k.endswith(".weight"): key_lora = k[len("diffusion_model."):-len(".weight")] - key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format - key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # SimpleTuner lycoris format + key_map["transformer.{}".format(key_lora)] = k # SimpleTuner regular format if isinstance(model, model_base.ACEStep): for k in sdk: - if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format + if k.startswith("diffusion_model.") and k.endswith(".weight"): # Official ACE step lora format key_lora = k[len("diffusion_model."):-len(".weight")] key_map["{}".format(key_lora)] = k @@ -364,7 +367,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d if isinstance(v, weight_adapter.WeightAdapterBase): output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) if output is None: - logging.warning("Calculate Weight Failed: {} {}".format(v.name, key)) + logger.warning("Calculate Weight Failed: {} {}".format(v.name, key)) else: weight = output if old_weight is not None: @@ -382,12 +385,12 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d # An extra flag to pad the weight if the diff's shape is larger than the weight do_pad_weight = len(v) > 1 and v[1]['pad_weight'] if do_pad_weight and diff.shape != weight.shape: - logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape)) + logger.info("Pad weight {} from {} to 
shape: {}".format(key, weight.shape, diff.shape)) weight = pad_tensor_to_shape(weight, diff.shape) if strength != 0.0: if diff.shape != weight.shape: - logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) + logger.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape)) else: weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype)) elif patch_type == "set": @@ -398,7 +401,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) weight += function(strength * model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) else: - logging.warning("patch type not recognized {} {}".format(patch_type, key)) + logger.warning("patch type not recognized {} {}".format(patch_type, key)) if old_weight is not None: weight = old_weight diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index ca1ed34f4..53b38141e 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -1,8 +1,5 @@ from __future__ import annotations -from itertools import chain -from os.path import join - import collections import logging import operator @@ -10,6 +7,8 @@ import os import shutil from collections.abc import Sequence, MutableSequence from functools import reduce +from itertools import chain +from os.path import join from pathlib import Path from typing import List, Optional, Final, Set @@ -66,28 +65,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[ filename = canonicalize_path(filename) path = folder_paths.get_full_path(folder_name, filename) + candidate_str_match = False + candidate_filename_match = False + candidate_alternate_filenames_match = False + candidate_save_filename_match = False if path is None and not args.disable_known_models: try: # todo: should this be the first or last path? 
this_model_directory = folder_paths.get_folder_paths(folder_name)[0] known_file: Optional[HuggingFile | CivitFile] = None for candidate in known_files: - if (canonicalize_path(str(candidate)) == filename - or canonicalize_path(candidate.filename) == filename - or filename in list(map(canonicalize_path, candidate.alternate_filenames)) - or filename == canonicalize_path(candidate.save_with_filename)): + candidate_str_match = canonicalize_path(str(candidate)) == filename + candidate_filename_match = canonicalize_path(candidate.filename) == filename + candidate_alternate_filenames_match = filename in list(map(canonicalize_path, candidate.alternate_filenames)) + candidate_save_filename_match = filename == canonicalize_path(candidate.save_with_filename) + if (candidate_str_match + or candidate_filename_match + or candidate_alternate_filenames_match + or candidate_save_filename_match): known_file = candidate break if known_file is None: + logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}") return path with comfy_tqdm(): if isinstance(known_file, HuggingFile): + symlinks_supported = are_symlinks_supported() if known_file.save_with_filename is not None: linked_filename = known_file.save_with_filename elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename: linked_filename = os.path.basename(known_file.filename) else: - linked_filename = None + linked_filename = known_file.filename if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "": # if the known file has an overridden linked name, save it into a repo_id sub directory @@ -112,8 +121,6 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[ cache_hit = False try: - if not are_symlinks_supported(): - raise PermissionError("no symlink support") # always retrieve this from the cache if it already exists there path = hf_hub_download(repo_id=known_file.repo_id, filename=known_file.filename, @@ -121,19 +128,20 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[ revision=known_file.revision, local_files_only=True, ) - logger.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}") - if linked_filename is None: - linked_filename = known_file.filename + logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}") cache_hit = True - except (LocalEntryNotFoundError, PermissionError): - path = hf_hub_download(repo_id=known_file.repo_id, - filename=known_file.filename, - local_dir=hf_destination_dir, - repo_type=known_file.repo_type, - revision=known_file.revision, - ) + except LocalEntryNotFoundError: + try: + logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}") + path = hf_hub_download(repo_id=known_file.repo_id, + filename=known_file.filename, + repo_type=known_file.repo_type, + revision=known_file.revision, + ) + except IOError as exc_info: + logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info) - if known_file.convert_to_16_bit and file_size is not None and file_size != 0: + if path is not None and known_file.convert_to_16_bit and file_size is not None 
and file_size != 0: tensors = {} with safe_open(path, framework="pt") as f: with tqdm.tqdm(total=len(f.keys())) as pb: @@ -151,20 +159,23 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[ logger.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}") link_successful = True - if linked_filename is not None: + if path is not None: destination_link = os.path.join(this_model_directory, linked_filename) - try: - os.makedirs(this_model_directory, exist_ok=True) - os.symlink(path, destination_link) - except Exception as exc_info: - logger.error("error while symbolic linking", exc_info=exc_info) + if Path(destination_link).is_file(): + logger.warning(f"{known_file.repo_id}/{known_file.filename} could not link to {destination_link} because the path already exists, which is unexpected") + else: try: - os.link(path, destination_link) - except Exception as hard_link_exc: - logger.error("error while hard linking", exc_info=hard_link_exc) - if cache_hit: - shutil.copyfile(path, destination_link) - link_successful = False + os.makedirs(this_model_directory, exist_ok=True) + os.symlink(path, destination_link) + except Exception as exc_info: + logger.error("error while symbolic linking", exc_info=exc_info) + try: + os.link(path, destination_link) + except Exception as hard_link_exc: + logger.error("error while hard linking", exc_info=hard_link_exc) + if cache_hit: + shutil.copyfile(path, destination_link) + link_successful = False if not link_successful: logger.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info) @@ -558,7 +569,9 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"), HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"), HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"), + HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf"), HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"), + HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q4_K_M.gguf"), ], folder_names=["diffusion_models", "unet"]) KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c21c532e4..5dc31c8e4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -630,7 +630,7 @@ class ModelPatcher(ModelManageable): # from gguf if is_quantized(weight): out_weight = weight.to(device_to) - patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device) + patches = move_patch_to_device(self.patches[key], self.load_device if self.gguf.patch_on_device else self.offload_device) # TODO: do we ever have legitimate duplicate patches? (i.e. 
patch on top of patched weight) out_weight.patches = [(patches, key)] if inplace_update: diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 820dbe27e..09b2c5852 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -674,7 +674,7 @@ class LoraLoader: lora = utils.load_torch_file(lora_path, safe_load=True) self.loaded_lora = (lora_path, lora) - model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) + model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=lora_name) return (model_lora, clip_lora) class LoraLoaderModelOnly(LoraLoader): diff --git a/comfy/sd.py b/comfy/sd.py index afbb102df..2272a30e4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -15,7 +15,6 @@ import yaml from . import clip_vision from . import diffusers_convert from . import gligen -from . import lora from . import model_detection from . import model_management from . import model_patcher @@ -37,6 +36,7 @@ from .ldm.lightricks.vae import causal_video_autoencoder as lightricks from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.wan import vae as wan_vae from .ldm.wan import vae2_2 as wan_vae2_2 +from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip from .lora_convert import convert_lora from .model_management import load_models_gpu from .model_patcher import ModelPatcher @@ -64,15 +64,15 @@ from .utils import ProgressBar, FileMetadata logger = logging.getLogger(__name__) -def load_lora_for_models(model, clip, _lora, strength_model, strength_clip): +def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=None): key_map = {} if model is not None: - key_map = lora.model_lora_keys_unet(model.model, key_map) + key_map = model_lora_keys_unet(model.model, key_map) if clip is not None: - key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) - _lora = convert_lora(_lora) - loaded = lora.load_lora(_lora, key_map) + lora = convert_lora(lora) + loaded = load_lora(lora, key_map, lora_name=lora_name) if model is not None: new_modelpatcher: ModelPatcher = model.clone() k = new_modelpatcher.add_patches(loaded, strength_model) @@ -90,7 +90,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip): k1 = set(k1) for x in loaded: if (x not in k) and (x not in k1): - logger.warning("NOT LOADED {}".format(x)) + logger.warning(f"[{lora_name}] clip keys not loaded {x}".format(x)) return (new_modelpatcher, new_clip) diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py index e93a06ce9..2f4403e98 100644 --- a/comfy_extras/nodes/nodes_torch_compile.py +++ b/comfy_extras/nodes/nodes_torch_compile.py @@ -80,7 +80,6 @@ class TorchCompileModel(CustomNode): "backend": backend, "mode": mode, } - move_to_gpu = True try: if backend == "torch_tensorrt": try: @@ -98,7 +97,6 @@ class TorchCompileModel(CustomNode): "enable_weight_streaming": True, "make_refittable": True, } - move_to_gpu = True del compile_kwargs["mode"] if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)): to_return = model.clone() @@ -109,27 +107,18 @@ class TorchCompileModel(CustomNode): object_patches = ["encoder", "decoder"] else: patcher = to_return - if object_patch is None or len(object_patches) == 0: + if object_patch is None or len(object_patches) == 0 or len(object_patches) == 1 and object_patches[0].strip() == "": 
object_patches = [DIFFUSION_MODEL] - if move_to_gpu: - model_management.unload_all_models() - model_management.load_models_gpu([patcher]) set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs) - # m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs)) - # todo: do we want to move something back off the GPU? - # if move_to_gpu: - # model_management.unload_all_models() return to_return, elif isinstance(model, torch.nn.Module): - if move_to_gpu: - model_management.unload_all_models() - model.to(device=model_management.get_torch_device()) + model_management.unload_all_models() + model.to(device=model_management.get_torch_device()) res = torch.compile(model=model, **compile_kwargs), - if move_to_gpu: - model.to(device=model_management.unet_offload_device()) + model.to(device=model_management.unet_offload_device()) return res, else: - logger.warning("Encountered a model that cannot be compiled") + logger.warning(f"Encountered a model {model} that cannot be compiled") return model, except OSError as os_error: try: @@ -174,7 +163,8 @@ class QuantizeModel(CustomNode): def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]: model = model.clone() - unet = model.get_model_object("diffusion_model") + model.patch_model(force_patch_weights=True) + unet = model.diffusion_model # todo: quantize quantizes in place, which is not desired # default exclusions @@ -209,7 +199,7 @@ class QuantizeModel(CustomNode): if "autoquant" in strategy: _in_place_fixme = autoquant(unet, error_on_unseen=False) else: - quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device()) + quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter) _in_place_fixme = unet unwrap_tensor_subclass(_in_place_fixme) else: diff --git a/comfy_extras/nodes/nodes_train.py b/comfy_extras/nodes/nodes_train.py index 61aad31e8..d2a161096 100644 --- a/comfy_extras/nodes/nodes_train.py +++ b/comfy_extras/nodes/nodes_train.py @@ -717,7 +717,7 @@ class LoraModelLoader: if strength_model == 0: return (model, ) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) + model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0, None) return (model_lora, )
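
Usage sketch for the new colab helpers (comfy/app/colab.py): a minimal example of how start_server_in_colab() might be called from a notebook cell, assuming the import path comfy.app.colab added above and the default port 8188 it binds. As written in the patch, the helper relies on the notebook kernel's already-running asyncio event loop (via asyncio.get_running_loop()) to schedule the ComfyUI server task, and it returns the public trycloudflare.com URL of the cloudflared tunnel.

    # In a Colab / Jupyter cell; the kernel's asyncio event loop is already
    # running, which start_server_in_colab() uses to schedule the server task.
    from comfy.app.colab import start_server_in_colab

    url = start_server_in_colab()  # starts ComfyUI and a cloudflared tunnel on port 8188
    print(url)                     # e.g. https://<subdomain>.trycloudflare.com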