Improvements to GGUF

- GGUF now works and includes 4-bit quants, allowing WAN to run on GPUs with 24 GB of VRAM
- the logger only shows full stack traces for errors
- helper functions for the Colab notebook
- fix the nvcr.io auth error
- LoRA issues are now reported more clearly
- the model downloader will use the Hugging Face cache and symlinking if it is supported on your platform
- the torch compile node now correctly patches the model before compilation
doctorpangloss 2025-07-30 18:28:52 -07:00
parent 69a4906964
commit 83184916c1
14 changed files with 256 additions and 96 deletions

View File

@ -25,6 +25,11 @@ jobs:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- uses: docker/login-action@v3
with:
registry: nvcr.io
username: "$oauthtoken"
password: ${{ secrets.NVCR_NGC_TOKEN }}
- name: Build and push CUDA (NVIDIA) image
uses: docker/build-push-action@v6
with:

comfy/app/colab.py (new file, 138 lines)
View File

@ -0,0 +1,138 @@
import asyncio
import os
import stat
import subprocess
import threading
from asyncio import Task
from typing import NamedTuple, Optional

import requests

from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
from ..execution_context import *


class _ColabTuple(NamedTuple):
    tunnel: "CloudflaredTunnel"
    server: Task


_colab_instances: list[_ColabTuple] = []


class CloudflaredTunnel:
    """
    A class to manage a cloudflared tunnel subprocess.

    Provides methods to start, stop, and manage the lifecycle of the tunnel.
    It can be used as a context manager.
    """

    def __init__(self, port: int):
        self._port: int = port
        self._executable_path: str = "./cloudflared"
        self._process: Optional[subprocess.Popen] = None
        self._thread: Optional[threading.Thread] = None
        # Download and set permissions for the executable
        self._setup_executable()
        # Start the tunnel process and capture the URL
        self.url: str = self._start_tunnel()

    def _setup_executable(self):
        """Downloads cloudflared and makes it executable if it doesn't exist."""
        if not os.path.exists(self._executable_path):
            url = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(self._executable_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            # Make the file executable (add execute permission for the owner)
            current_permissions = os.stat(self._executable_path).st_mode
            os.chmod(self._executable_path, current_permissions | stat.S_IEXEC)

    def _start_tunnel(self) -> str:
        """Starts the tunnel and returns the public URL."""
        command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"]
        # Using DEVNULL for stderr to keep the output clean, stdout is piped
        self._process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True
        )
        for line in iter(self._process.stdout.readline, ""):
            if ".trycloudflare.com" in line:
                # The line format is typically: "INFO | https://<subdomain>.trycloudflare.com |"
                try:
                    url = line.split("|")[1].strip()
                    print(f"Tunnel is live at: {url}")
                    return url
                except IndexError:
                    continue
        # If the loop finishes without finding a URL
        self.stop()
        raise RuntimeError("Failed to start cloudflared tunnel or find URL.")

    def stop(self):
        """Stops the cloudflared tunnel process."""
        if self._process and self._process.poll() is None:
            print("Stopping cloudflared tunnel...")
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
                print("Tunnel stopped successfully.")
            except subprocess.TimeoutExpired:
                print("Tunnel did not terminate gracefully, forcing kill.")
                self._process.kill()
            self._process = None

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring the tunnel is stopped."""
        self.stop()


def start_tunnel(port: int) -> CloudflaredTunnel:
    """
    Initializes and starts a cloudflared tunnel.

    Args:
        port: The local port number to expose.

    Returns:
        A CloudflaredTunnel object that controls the tunnel process.
        This object has a `url` attribute and a `stop()` method.
    """
    return CloudflaredTunnel(port)


def start_server_in_colab() -> str:
    """
    Starts the ComfyUI server on the running event loop (if it is not already running)
    and returns the public URL of the tunnel.
    """
    if len(_colab_instances) == 0:
        comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))

        async def colab_server_loop():
            init_default_paths(folder_names_and_paths)
            await _start_comfyui()

        _loop = asyncio.get_running_loop()
        task = _loop.create_task(colab_server_loop())
        tunnel = start_tunnel(8188)
        _colab_instances.append(_ColabTuple(tunnel, task))
    return _colab_instances[0].tunnel.url

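A minimal usage sketch for this helper from a Colab cell; the import path is inferred from the file location above, and the cell-level call is an assumption about how the notebook uses it:

    from comfy.app.colab import start_server_in_colab  # assumed import path for comfy/app/colab.py

    # Starts the ComfyUI server on the notebook's running event loop (if needed)
    # and exposes port 8188 through a cloudflared quick tunnel.
    url = start_server_in_colab()
    print(f"ComfyUI is reachable at {url}")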
View File

@ -55,7 +55,7 @@ def on_flush(callback):
class StackTraceLogger(logging.Logger):
def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
if level >= logging.WARNING:
if level >= logging.ERROR:
stack_info = True
super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)
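The change above attaches stack traces only to records at ERROR and above. A plain-logging sketch of the equivalent behavior, not the project's actual wiring:

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("demo")

    log.warning("low VRAM, falling back")          # no stack trace attached anymore
    log.error("workflow failed", stack_info=True)  # errors still carry the full stack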

View File

@ -171,6 +171,16 @@ class Comfy:
def task_count(self) -> int:
return self._task_count
def __enter__(self):
self._is_running = True
return self
def __exit__(self, *args):
get_event_loop().run_in_executor(self._executor, _cleanup)
self._executor.shutdown(wait=True)
self._is_running = False
async def __aenter__(self):
self._is_running = True
return self

View File

@ -19,15 +19,14 @@ from typing import List, Optional, Tuple, Literal
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
# order matters
from .main_pre import tracer
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
DependencyAwareCache, \
BasicCache
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
# order matters
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..cli_args import args
@ -38,7 +37,8 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path
from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
ExecutionStatusAsDict
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@ -48,6 +48,7 @@ from ..progress import get_progress_state, reset_progress_state, add_progress_ha
from ..validation import validate_node_input
_module_properties = create_module_properties()
logger = logging.getLogger(__name__)
@_module_properties.getter
@ -100,12 +101,12 @@ class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
logging.info("Disabling intermediate node cache.")
logger.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
logger.info("Using LRU cache")
else:
self.init_classic_cache()
@ -571,7 +572,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
except interruption.InterruptProcessingException as iex:
logging.info("Processing interrupted")
logger.info("Processing interrupted")
# skip formatting inputs/outputs
error_details: RecursiveExecutionErrorDetailsInterrupted = {
@ -588,13 +589,13 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs]
logging.error("An error occurred while executing a workflow", exc_info=ex)
logging.error(traceback.format_exc())
logger.error("An error occurred while executing a workflow", exc_info=ex)
logger.error(traceback.format_exc())
tips = ""
if isinstance(ex, model_management.OOM_EXCEPTION):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.error("Got an OOM, unloading all loaded models.")
logger.error("Got an OOM, unloading all loaded models.")
model_management.unload_all_models()
error_details: RecursiveExecutionErrorDetails = {
@ -606,7 +607,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
}
if should_panic_on_exception(ex, args.panic_when):
logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
def sys_exit(*args):
sys.exit(1)
@ -1120,11 +1121,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
if valid is True:
good_outputs.add(o)
else:
logging.error(f"Failed to validate prompt for output {o}:")
logger.error(f"Failed to validate prompt for output {o}:")
if len(reasons) > 0:
logging.error("* (prompt):")
logger.error("* (prompt):")
for reason in reasons:
logging.error(f" - {reason['message']}: {reason['details']}")
logger.error(f" - {reason['message']}: {reason['details']}")
errors += [(o, reasons)]
for node_id, result in validated.items():
valid = result[0]
@ -1140,11 +1141,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
"dependent_outputs": [],
"class_type": class_type
}
logging.error(f"* {class_type} {node_id}:")
logger.error(f"* {class_type} {node_id}:")
for reason in reasons:
logging.error(f" - {reason['message']}: {reason['details']}")
logger.error(f" - {reason['message']}: {reason['details']}")
node_errors[node_id]["dependent_outputs"].append(o)
logging.error("Output will be ignored")
logger.error("Output will be ignored")
if len(good_outputs) == 0:
errors_list = []

View File

@ -93,7 +93,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["clip", "text_encoders"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["text_encoders", "clip"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),

View File

@ -939,16 +939,16 @@ def get_torch_compiler_disable_decorator():
from packaging import version
if not chained_hasattr(torch, "compiler.disable"):
logger.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
return dummy_decorator # torch too old
elif version.parse(torch.__version__) >= version.parse("2.8"):
logger.info("ComfyUI-GGUF: Allowing full torch compile")
logger.debug("ComfyUI-GGUF: Allowing full torch compile")
return dummy_decorator # torch compile works
if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
logger.info("ComfyUI-GGUF: Allowing full torch compile (nightly)")
logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
return dummy_decorator # torch compile works, nightly before 2.8 release
else:
logger.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
return torch.compiler.disable

View File

@ -24,8 +24,10 @@ import torch
from . import model_base
from . import model_management
from . import utils
from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
from . import weight_adapter
from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
logger = logging.getLogger(__name__)
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@ -37,7 +39,7 @@ LORA_CLIP_MAP = {
}
def load_lora(lora, to_load, log_missing=True) -> PatchDict:
def load_lora(lora, to_load, log_missing=True, lora_name=None) -> PatchDict:
patch_dict: PatchDict = {}
loaded_keys = set()
for x in to_load:
@ -91,9 +93,10 @@ def load_lora(lora, to_load, log_missing=True) -> PatchDict:
loaded_keys.add(set_weight_name)
if log_missing:
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
not_loaded_keys = [x for x in lora.keys() if x not in loaded_keys]
n_not_loaded_keys = len(not_loaded_keys)
if n_not_loaded_keys > 0:
logger.warning(f"[{lora_name}] lora keys not loaded ({n_not_loaded_keys} / {len(loaded_keys) + n_not_loaded_keys}): {not_loaded_keys}")
return patch_dict
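With this change, missing keys are reported once per LoRA as a single summary line instead of one warning per key. An illustrative record built from the f-string above (the key names are hypothetical):

    [my_style_lora.safetensors] lora keys not loaded (2 / 294): ['lora_unet_foo.alpha', 'lora_unet_bar.alpha']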
@ -293,12 +296,12 @@ def model_lora_keys_unet(model, key_map=None):
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # SimpleTuner lycoris format
key_map["transformer.{}".format(key_lora)] = k # SimpleTuner regular format
if isinstance(model, model_base.ACEStep):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
if k.startswith("diffusion_model.") and k.endswith(".weight"): # Official ACE step lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
@ -364,7 +367,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
if isinstance(v, weight_adapter.WeightAdapterBase):
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
if output is None:
logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
logger.warning("Calculate Weight Failed: {} {}".format(v.name, key))
else:
weight = output
if old_weight is not None:
@ -382,12 +385,12 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
# An extra flag to pad the weight if the diff's shape is larger than the weight
do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
logger.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
weight = pad_tensor_to_shape(weight, diff.shape)
if strength != 0.0:
if diff.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
logger.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
@ -398,7 +401,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
logger.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight

View File

@ -1,8 +1,5 @@
from __future__ import annotations
from itertools import chain
from os.path import join
import collections
import logging
import operator
@ -10,6 +7,8 @@ import os
import shutil
from collections.abc import Sequence, MutableSequence
from functools import reduce
from itertools import chain
from os.path import join
from pathlib import Path
from typing import List, Optional, Final, Set
@ -66,28 +65,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
filename = canonicalize_path(filename)
path = folder_paths.get_full_path(folder_name, filename)
candidate_str_match = False
candidate_filename_match = False
candidate_alternate_filenames_match = False
candidate_save_filename_match = False
if path is None and not args.disable_known_models:
try:
# todo: should this be the first or last path?
this_model_directory = folder_paths.get_folder_paths(folder_name)[0]
known_file: Optional[HuggingFile | CivitFile] = None
for candidate in known_files:
if (canonicalize_path(str(candidate)) == filename
or canonicalize_path(candidate.filename) == filename
or filename in list(map(canonicalize_path, candidate.alternate_filenames))
or filename == canonicalize_path(candidate.save_with_filename)):
candidate_str_match = canonicalize_path(str(candidate)) == filename
candidate_filename_match = canonicalize_path(candidate.filename) == filename
candidate_alternate_filenames_match = filename in list(map(canonicalize_path, candidate.alternate_filenames))
candidate_save_filename_match = filename == canonicalize_path(candidate.save_with_filename)
if (candidate_str_match
or candidate_filename_match
or candidate_alternate_filenames_match
or candidate_save_filename_match):
known_file = candidate
break
if known_file is None:
logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}")
return path
with comfy_tqdm():
if isinstance(known_file, HuggingFile):
symlinks_supported = are_symlinks_supported()
if known_file.save_with_filename is not None:
linked_filename = known_file.save_with_filename
elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
linked_filename = os.path.basename(known_file.filename)
else:
linked_filename = None
linked_filename = known_file.filename
if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
# if the known file has an overridden linked name, save it into a repo_id sub directory
@ -112,8 +121,6 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
cache_hit = False
try:
if not are_symlinks_supported():
raise PermissionError("no symlink support")
# always retrieve this from the cache if it already exists there
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
@ -121,19 +128,20 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
revision=known_file.revision,
local_files_only=True,
)
logger.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
if linked_filename is None:
linked_filename = known_file.filename
logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
cache_hit = True
except (LocalEntryNotFoundError, PermissionError):
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
local_dir=hf_destination_dir,
repo_type=known_file.repo_type,
revision=known_file.revision,
)
except LocalEntryNotFoundError:
try:
logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
repo_type=known_file.repo_type,
revision=known_file.revision,
)
except IOError as exc_info:
logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
tensors = {}
with safe_open(path, framework="pt") as f:
with tqdm.tqdm(total=len(f.keys())) as pb:
@ -151,20 +159,23 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
logger.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
link_successful = True
if linked_filename is not None:
if path is not None:
destination_link = os.path.join(this_model_directory, linked_filename)
try:
os.makedirs(this_model_directory, exist_ok=True)
os.symlink(path, destination_link)
except Exception as exc_info:
logger.error("error while symbolic linking", exc_info=exc_info)
if Path(destination_link).is_file():
logger.warning(f"{known_file.repo_id}/{known_file.filename} could not link to {destination_link} because the path already exists, which is unexpected")
else:
try:
os.link(path, destination_link)
except Exception as hard_link_exc:
logger.error("error while hard linking", exc_info=hard_link_exc)
if cache_hit:
shutil.copyfile(path, destination_link)
link_successful = False
os.makedirs(this_model_directory, exist_ok=True)
os.symlink(path, destination_link)
except Exception as exc_info:
logger.error("error while symbolic linking", exc_info=exc_info)
try:
os.link(path, destination_link)
except Exception as hard_link_exc:
logger.error("error while hard linking", exc_info=hard_link_exc)
if cache_hit:
shutil.copyfile(path, destination_link)
link_successful = False
if not link_successful:
logger.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
@ -558,7 +569,9 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf"),
HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q4_K_M.gguf"),
], folder_names=["diffusion_models", "unet"])
KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([

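A hedged sketch of pulling one of the newly registered Q4_K_M quants through the downloader; get_or_download and KNOWN_UNET_MODELS appear in the hunks above, while the module path is an assumption:

    from comfy.model_downloader import get_or_download, KNOWN_UNET_MODELS  # assumed module path

    # Resolves the file locally if present; otherwise downloads it into the
    # Hugging Face cache and links (or copies) it into the diffusion_models folder.
    path = get_or_download("diffusion_models",
                           "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf",
                           KNOWN_UNET_MODELS)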
View File

@ -630,7 +630,7 @@ class ModelPatcher(ModelManageable):
# from gguf
if is_quantized(weight):
out_weight = weight.to(device_to)
patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
patches = move_patch_to_device(self.patches[key], self.load_device if self.gguf.patch_on_device else self.offload_device)
# TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
out_weight.patches = [(patches, key)]
if inplace_update:

View File

@ -674,7 +674,7 @@ class LoraLoader:
lora = utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=lora_name)
return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader):

View File

@ -15,7 +15,6 @@ import yaml
from . import clip_vision
from . import diffusers_convert
from . import gligen
from . import lora
from . import model_detection
from . import model_management
from . import model_patcher
@ -37,6 +36,7 @@ from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.wan import vae as wan_vae
from .ldm.wan import vae2_2 as wan_vae2_2
from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip
from .lora_convert import convert_lora
from .model_management import load_models_gpu
from .model_patcher import ModelPatcher
@ -64,15 +64,15 @@ from .utils import ProgressBar, FileMetadata
logger = logging.getLogger(__name__)
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=None):
key_map = {}
if model is not None:
key_map = lora.model_lora_keys_unet(model.model, key_map)
key_map = model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
_lora = convert_lora(_lora)
loaded = lora.load_lora(_lora, key_map)
lora = convert_lora(lora)
loaded = load_lora(lora, key_map, lora_name=lora_name)
if model is not None:
new_modelpatcher: ModelPatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
@ -90,7 +90,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
k1 = set(k1)
for x in loaded:
if (x not in k) and (x not in k1):
logger.warning("NOT LOADED {}".format(x))
logger.warning(f"[{lora_name}] clip keys not loaded {x}".format(x))
return (new_modelpatcher, new_clip)

View File

@ -80,7 +80,6 @@ class TorchCompileModel(CustomNode):
"backend": backend,
"mode": mode,
}
move_to_gpu = True
try:
if backend == "torch_tensorrt":
try:
@ -98,7 +97,6 @@ class TorchCompileModel(CustomNode):
"enable_weight_streaming": True,
"make_refittable": True,
}
move_to_gpu = True
del compile_kwargs["mode"]
if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
to_return = model.clone()
@ -109,27 +107,18 @@ class TorchCompileModel(CustomNode):
object_patches = ["encoder", "decoder"]
else:
patcher = to_return
if object_patch is None or len(object_patches) == 0:
if object_patch is None or len(object_patches) == 0 or len(object_patches) == 1 and object_patches[0].strip() == "":
object_patches = [DIFFUSION_MODEL]
if move_to_gpu:
model_management.unload_all_models()
model_management.load_models_gpu([patcher])
set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
# m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
# todo: do we want to move something back off the GPU?
# if move_to_gpu:
# model_management.unload_all_models()
return to_return,
elif isinstance(model, torch.nn.Module):
if move_to_gpu:
model_management.unload_all_models()
model.to(device=model_management.get_torch_device())
model_management.unload_all_models()
model.to(device=model_management.get_torch_device())
res = torch.compile(model=model, **compile_kwargs),
if move_to_gpu:
model.to(device=model_management.unet_offload_device())
model.to(device=model_management.unet_offload_device())
return res,
else:
logger.warning("Encountered a model that cannot be compiled")
logger.warning(f"Encountered a model {model} that cannot be compiled")
return model,
except OSError as os_error:
try:
@ -174,7 +163,8 @@ class QuantizeModel(CustomNode):
def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
model = model.clone()
unet = model.get_model_object("diffusion_model")
model.patch_model(force_patch_weights=True)
unet = model.diffusion_model
# todo: quantize quantizes in place, which is not desired
# default exclusions
@ -209,7 +199,7 @@ class QuantizeModel(CustomNode):
if "autoquant" in strategy:
_in_place_fixme = autoquant(unet, error_on_unseen=False)
else:
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
_in_place_fixme = unet
unwrap_tensor_subclass(_in_place_fixme)
else:

View File

@ -717,7 +717,7 @@ class LoraModelLoader:
if strength_model == 0:
return (model, )
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0, None)
return (model_lora, )