Mirror of https://github.com/comfyanonymous/ComfyUI.git
Improvements to GGUF
- GGUF now works and includes 4-bit quants, which allows WAN to run on 24GB VRAM GPUs
- the logger only shows the full stack trace for errors
- helper functions for the Colab notebook
- fix nvcr.io auth error
- LoRA issues are now reported more clearly
- the model downloader will use the Hugging Face cache and symlinking if your platform supports it
- the torch compile node now correctly patches the model before compilation
This commit is contained in:
parent 69a4906964
commit 83184916c1
.github/workflows/docker-build.yml (vendored, 5 changes)
@@ -25,6 +25,11 @@ jobs:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - uses: docker/login-action@v3
        with:
          registry: nvcr.io
          username: "$oauthtoken"
          password: ${{ secrets.NVCR_NGC_TOKEN }}
      - name: Build and push CUDA (NVIDIA) image
        uses: docker/build-push-action@v6
        with:
comfy/app/colab.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import asyncio
import os
import stat
import subprocess
import threading
from asyncio import Task
from typing import NamedTuple

import requests

from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
from ..execution_context import *


class _ColabTuple(NamedTuple):
    tunnel: "CloudflaredTunnel"
    server: Task


_colab_instances: list[_ColabTuple] = []


class CloudflaredTunnel:
    """
    A class to manage a cloudflared tunnel subprocess.

    Provides methods to start, stop, and manage the lifecycle of the tunnel.
    It can be used as a context manager.
    """

    def __init__(self, port: int):
        self._port: int = port
        self._executable_path: str = "./cloudflared"
        self._process: Optional[subprocess.Popen] = None
        self._thread: Optional[threading.Thread] = None

        # Download and set permissions for the executable
        self._setup_executable()

        # Start the tunnel process and capture the URL
        self.url: str = self._start_tunnel()

    def _setup_executable(self):
        """Downloads cloudflared and makes it executable if it doesn't exist."""
        if not os.path.exists(self._executable_path):
            url = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(self._executable_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Make the file executable (add execute permission for the owner)
        current_permissions = os.stat(self._executable_path).st_mode
        os.chmod(self._executable_path, current_permissions | stat.S_IEXEC)

    def _start_tunnel(self) -> str:
        """Starts the tunnel and returns the public URL."""
        command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"]

        # Using DEVNULL for stderr to keep the output clean, stdout is piped
        self._process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True
        )

        for line in iter(self._process.stdout.readline, ""):
            if ".trycloudflare.com" in line:
                # The line format is typically: "INFO | https://<subdomain>.trycloudflare.com |"
                try:
                    url = line.split("|")[1].strip()
                    print(f"Tunnel is live at: {url}")
                    return url
                except IndexError:
                    continue

        # If the loop finishes without finding a URL
        self.stop()
        raise RuntimeError("Failed to start cloudflared tunnel or find URL.")

    def stop(self):
        """Stops the cloudflared tunnel process."""
        if self._process and self._process.poll() is None:
            print("Stopping cloudflared tunnel...")
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
                print("Tunnel stopped successfully.")
            except subprocess.TimeoutExpired:
                print("Tunnel did not terminate gracefully, forcing kill.")
                self._process.kill()
            self._process = None

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring the tunnel is stopped."""
        self.stop()


def start_tunnel(port: int) -> CloudflaredTunnel:
    """
    Initializes and starts a cloudflared tunnel.

    Args:
        port: The local port number to expose.

    Returns:
        A CloudflaredTunnel object that controls the tunnel process.
        This object has a `url` attribute and a `stop()` method.
    """
    return CloudflaredTunnel(port)


def start_server_in_colab() -> str:
    """
    Starts ComfyUI in the notebook's event loop (if it is not already running)
    and returns the public tunnel URL.
    """
    if len(_colab_instances) == 0:
        comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))

        async def colab_server_loop():
            init_default_paths(folder_names_and_paths)
            await _start_comfyui()

        _loop = asyncio.get_running_loop()
        task = _loop.create_task(colab_server_loop())

        tunnel = start_tunnel(8188)
        _colab_instances.append(_ColabTuple(tunnel, task))
    return _colab_instances[0].tunnel.url
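For orientation, here is a minimal notebook-style sketch of how these helpers are presumably used. The import path simply follows the new file's location, and the sketch assumes the notebook kernel exposes a running asyncio event loop, which start_server_in_colab() requires via asyncio.get_running_loop().

# Hypothetical notebook cell, not part of the commit: start ComfyUI and print
# the public tunnel URL. Assumes the `comfy` package from this repository is
# installed and that an asyncio event loop is already running in the kernel.
from comfy.app.colab import start_server_in_colab

url = start_server_in_colab()  # downloads cloudflared, schedules the server task, opens the tunnel
print(f"Open ComfyUI at: {url}")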
@@ -55,7 +55,7 @@ def on_flush(callback):

class StackTraceLogger(logging.Logger):
    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
        if level >= logging.WARNING:
        if level >= logging.ERROR:
            stack_info = True
        super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)
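A Logger subclass like this only takes effect for loggers created after it is registered; below is a small self-contained sketch of the standard registration pattern. Whether ComfyUI wires it up exactly this way is not shown in this hunk.

# Sketch only: install a Logger subclass so error-level records carry stack
# information, matching the behavior change above (threshold moved from WARNING to ERROR).
import logging

class StackTraceLogger(logging.Logger):
    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
        if level >= logging.ERROR:
            stack_info = True  # only errors get the full stack after this commit
        super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)

logging.setLoggerClass(StackTraceLogger)  # standard way to register a custom Logger class
logging.getLogger("demo").error("boom")   # this record includes a stack trace
logging.getLogger("demo").warning("hmm")  # this one does not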
@@ -171,6 +171,16 @@ class Comfy:
    def task_count(self) -> int:
        return self._task_count

    def __enter__(self):
        self._is_running = True
        return self

    def __exit__(self, *args):
        get_event_loop().run_in_executor(self._executor, _cleanup)
        self._executor.shutdown(wait=True)
        self._is_running = False


    async def __aenter__(self):
        self._is_running = True
        return self
@@ -19,15 +19,14 @@ from typing import List, Optional, Tuple, Literal
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status

# order matters
from .main_pre import tracer

from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
    DependencyAwareCache, \
    BasicCache
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
# order matters
from .main_pre import tracer
from .. import interruption
from .. import model_management
from ..cli_args import args
@@ -38,7 +37,8 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
    HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
from ..component_model.files import canonicalize_path
from ..component_model.module_property import create_module_properties
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
    ExecutionStatusAsDict
from ..execution_context import context_execute_node, context_execute_prompt
from ..execution_ext import should_panic_on_exception
from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@@ -48,6 +48,7 @@ from ..progress import get_progress_state, reset_progress_state, add_progress_ha
from ..validation import validate_node_input

_module_properties = create_module_properties()
logger = logging.getLogger(__name__)


@_module_properties.getter
@@ -100,12 +101,12 @@ class CacheSet:
    def __init__(self, cache_type=None, cache_size=None):
        if cache_type == CacheType.DEPENDENCY_AWARE:
            self.init_dependency_aware_cache()
            logging.info("Disabling intermediate node cache.")
            logger.info("Disabling intermediate node cache.")
        elif cache_type == CacheType.LRU:
            if cache_size is None:
                cache_size = 0
            self.init_lru_cache(cache_size)
            logging.info("Using LRU cache")
            logger.info("Using LRU cache")
        else:
            self.init_classic_cache()
@@ -571,7 +572,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
                return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
        caches.outputs.set(unique_id, output_data)
    except interruption.InterruptProcessingException as iex:
        logging.info("Processing interrupted")
        logger.info("Processing interrupted")

        # skip formatting inputs/outputs
        error_details: RecursiveExecutionErrorDetailsInterrupted = {
@@ -588,13 +589,13 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
            for name, inputs in input_data_all.items():
                input_data_formatted[name] = [format_value(x) for x in inputs]

        logging.error("An error occurred while executing a workflow", exc_info=ex)
        logging.error(traceback.format_exc())
        logger.error("An error occurred while executing a workflow", exc_info=ex)
        logger.error(traceback.format_exc())
        tips = ""

        if isinstance(ex, model_management.OOM_EXCEPTION):
            tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
            logging.error("Got an OOM, unloading all loaded models.")
            logger.error("Got an OOM, unloading all loaded models.")
            model_management.unload_all_models()

        error_details: RecursiveExecutionErrorDetails = {
@@ -606,7 +607,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
        }

        if should_panic_on_exception(ex, args.panic_when):
            logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
            logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")

            def sys_exit(*args):
                sys.exit(1)
@@ -1120,11 +1121,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
        if valid is True:
            good_outputs.add(o)
        else:
            logging.error(f"Failed to validate prompt for output {o}:")
            logger.error(f"Failed to validate prompt for output {o}:")
            if len(reasons) > 0:
                logging.error("* (prompt):")
                logger.error("* (prompt):")
                for reason in reasons:
                    logging.error(f" - {reason['message']}: {reason['details']}")
                    logger.error(f" - {reason['message']}: {reason['details']}")
            errors += [(o, reasons)]
            for node_id, result in validated.items():
                valid = result[0]
@@ -1140,11 +1141,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
                        "dependent_outputs": [],
                        "class_type": class_type
                    }
                logging.error(f"* {class_type} {node_id}:")
                logger.error(f"* {class_type} {node_id}:")
                for reason in reasons:
                    logging.error(f" - {reason['message']}: {reason['details']}")
                    logger.error(f" - {reason['message']}: {reason['details']}")
                node_errors[node_id]["dependent_outputs"].append(o)
            logging.error("Output will be ignored")
            logger.error("Output will be ignored")

    if len(good_outputs) == 0:
        errors_list = []
@@ -93,7 +93,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
        ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
        ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
        ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
        ModelPaths(folder_names=["clip", "text_encoders"], supported_extensions=set(supported_pt_extensions)),
        ModelPaths(folder_names=["text_encoders", "clip"], supported_extensions=set(supported_pt_extensions)),
        ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
        ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
        ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),
@@ -939,16 +939,16 @@ def get_torch_compiler_disable_decorator():
    from packaging import version

    if not chained_hasattr(torch, "compiler.disable"):
        logger.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
        logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
        return dummy_decorator  # torch too old
    elif version.parse(torch.__version__) >= version.parse("2.8"):
        logger.info("ComfyUI-GGUF: Allowing full torch compile")
        logger.debug("ComfyUI-GGUF: Allowing full torch compile")
        return dummy_decorator  # torch compile works
    if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
        logger.info("ComfyUI-GGUF: Allowing full torch compile (nightly)")
        logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
        return dummy_decorator  # torch compile works, nightly before 2.8 release
    else:
        logger.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
        logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
        return torch.compiler.disable
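The function above picks between a no-op decorator and torch.compiler.disable depending on the installed torch. Below is a standalone sketch of that selection pattern; the decorated function is illustrative and not ComfyUI-GGUF's actual code.

# Sketch of the version-gated decorator selection; the decorated function is a
# placeholder, not the real GGUF dequantization path.
import torch

def dummy_decorator(fn):
    return fn  # no-op when nothing needs to be excluded from compilation

if hasattr(torch, "compiler") and hasattr(torch.compiler, "disable"):
    compile_disable = torch.compiler.disable  # keep dynamo from tracing the decorated function
else:
    compile_disable = dummy_decorator  # torch too old to have torch.compiler.disable

@compile_disable
def dequantize_block(qweight):
    # work that dynamo may not be able to trace on some torch versions
    return qweight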
@@ -24,8 +24,10 @@ import torch
from . import model_base
from . import model_management
from . import utils
from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
from . import weight_adapter
from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue

logger = logging.getLogger(__name__)

LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
@@ -37,7 +39,7 @@ LORA_CLIP_MAP = {
}


def load_lora(lora, to_load, log_missing=True) -> PatchDict:
def load_lora(lora, to_load, log_missing=True, lora_name=None) -> PatchDict:
    patch_dict: PatchDict = {}
    loaded_keys = set()
    for x in to_load:
@@ -91,9 +93,10 @@ def load_lora(lora, to_load, log_missing=True) -> PatchDict:
        loaded_keys.add(set_weight_name)

    if log_missing:
        for x in lora.keys():
            if x not in loaded_keys:
                logging.warning("lora key not loaded: {}".format(x))
        not_loaded_keys = [x for x in lora.keys() if x not in loaded_keys]
        n_not_loaded_keys = len(not_loaded_keys)
        if n_not_loaded_keys > 0:
            logger.warning(f"[{lora_name}] lora keys not loaded ({n_not_loaded_keys} / {len(loaded_keys) + n_not_loaded_keys}): {not_loaded_keys}")

    return patch_dict
@@ -293,12 +296,12 @@ def model_lora_keys_unet(model, key_map=None):
            if k.startswith("diffusion_model."):
                if k.endswith(".weight"):
                    key_lora = k[len("diffusion_model."):-len(".weight")]
                    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
                    key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
                    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k  # SimpleTuner lycoris format
                    key_map["transformer.{}".format(key_lora)] = k  # SimpleTuner regular format

    if isinstance(model, model_base.ACEStep):
        for k in sdk:
            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
            if k.startswith("diffusion_model.") and k.endswith(".weight"):  # Official ACE step lora format
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["{}".format(key_lora)] = k
@@ -364,7 +367,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
        if isinstance(v, weight_adapter.WeightAdapterBase):
            output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
            if output is None:
                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
                logger.warning("Calculate Weight Failed: {} {}".format(v.name, key))
            else:
                weight = output
                if old_weight is not None:
@@ -382,12 +385,12 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
            # An extra flag to pad the weight if the diff's shape is larger than the weight
            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
            if do_pad_weight and diff.shape != weight.shape:
                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                logger.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                weight = pad_tensor_to_shape(weight, diff.shape)

            if strength != 0.0:
                if diff.shape != weight.shape:
                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                    logger.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                else:
                    weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype))
        elif patch_type == "set":
@@ -398,7 +401,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
                model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
            weight += function(strength * model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
        else:
            logging.warning("patch type not recognized {} {}".format(patch_type, key))
            logger.warning("patch type not recognized {} {}".format(patch_type, key))

        if old_weight is not None:
            weight = old_weight
@@ -1,8 +1,5 @@
from __future__ import annotations

from itertools import chain
from os.path import join

import collections
import logging
import operator
@@ -10,6 +7,8 @@ import os
import shutil
from collections.abc import Sequence, MutableSequence
from functools import reduce
from itertools import chain
from os.path import join
from pathlib import Path
from typing import List, Optional, Final, Set
@@ -66,28 +65,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
    filename = canonicalize_path(filename)
    path = folder_paths.get_full_path(folder_name, filename)

    candidate_str_match = False
    candidate_filename_match = False
    candidate_alternate_filenames_match = False
    candidate_save_filename_match = False
    if path is None and not args.disable_known_models:
        try:
            # todo: should this be the first or last path?
            this_model_directory = folder_paths.get_folder_paths(folder_name)[0]
            known_file: Optional[HuggingFile | CivitFile] = None
            for candidate in known_files:
                if (canonicalize_path(str(candidate)) == filename
                        or canonicalize_path(candidate.filename) == filename
                        or filename in list(map(canonicalize_path, candidate.alternate_filenames))
                        or filename == canonicalize_path(candidate.save_with_filename)):
                candidate_str_match = canonicalize_path(str(candidate)) == filename
                candidate_filename_match = canonicalize_path(candidate.filename) == filename
                candidate_alternate_filenames_match = filename in list(map(canonicalize_path, candidate.alternate_filenames))
                candidate_save_filename_match = filename == canonicalize_path(candidate.save_with_filename)
                if (candidate_str_match
                        or candidate_filename_match
                        or candidate_alternate_filenames_match
                        or candidate_save_filename_match):
                    known_file = candidate
                    break
            if known_file is None:
                logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}")
                return path
            with comfy_tqdm():
                if isinstance(known_file, HuggingFile):
                    symlinks_supported = are_symlinks_supported()
                    if known_file.save_with_filename is not None:
                        linked_filename = known_file.save_with_filename
                    elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
                        linked_filename = os.path.basename(known_file.filename)
                    else:
                        linked_filename = None
                        linked_filename = known_file.filename

                    if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
                        # if the known file has an overridden linked name, save it into a repo_id sub directory
@@ -112,8 +121,6 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[

                    cache_hit = False
                    try:
                        if not are_symlinks_supported():
                            raise PermissionError("no symlink support")
                        # always retrieve this from the cache if it already exists there
                        path = hf_hub_download(repo_id=known_file.repo_id,
                                               filename=known_file.filename,
@@ -121,19 +128,20 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                                               revision=known_file.revision,
                                               local_files_only=True,
                                               )
                        logger.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
                        if linked_filename is None:
                            linked_filename = known_file.filename
                        logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
                        cache_hit = True
                    except (LocalEntryNotFoundError, PermissionError):
                        path = hf_hub_download(repo_id=known_file.repo_id,
                                               filename=known_file.filename,
                                               local_dir=hf_destination_dir,
                                               repo_type=known_file.repo_type,
                                               revision=known_file.revision,
                                               )
                    except LocalEntryNotFoundError:
                        try:
                            logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
                            path = hf_hub_download(repo_id=known_file.repo_id,
                                                   filename=known_file.filename,
                                                   repo_type=known_file.repo_type,
                                                   revision=known_file.revision,
                                                   )
                        except IOError as exc_info:
                            logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)

                    if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
                    if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
                        tensors = {}
                        with safe_open(path, framework="pt") as f:
                            with tqdm.tqdm(total=len(f.keys())) as pb:
@@ -151,20 +159,23 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                        logger.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")

                    link_successful = True
                    if linked_filename is not None:
                    if path is not None:
                        destination_link = os.path.join(this_model_directory, linked_filename)
                        try:
                            os.makedirs(this_model_directory, exist_ok=True)
                            os.symlink(path, destination_link)
                        except Exception as exc_info:
                            logger.error("error while symbolic linking", exc_info=exc_info)
                        if Path(destination_link).is_file():
                            logger.warning(f"{known_file.repo_id}/{known_file.filename} could not link to {destination_link} because the path already exists, which is unexpected")
                        else:
                            try:
                                os.link(path, destination_link)
                            except Exception as hard_link_exc:
                                logger.error("error while hard linking", exc_info=hard_link_exc)
                                if cache_hit:
                                    shutil.copyfile(path, destination_link)
                                link_successful = False
                                os.makedirs(this_model_directory, exist_ok=True)
                                os.symlink(path, destination_link)
                            except Exception as exc_info:
                                logger.error("error while symbolic linking", exc_info=exc_info)
                                try:
                                    os.link(path, destination_link)
                                except Exception as hard_link_exc:
                                    logger.error("error while hard linking", exc_info=hard_link_exc)
                                    if cache_hit:
                                        shutil.copyfile(path, destination_link)
                                    link_successful = False

                    if not link_successful:
                        logger.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
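The linking logic above amounts to a fallback chain: symlink first, then hard link, then a plain copy when the file lives in the Hugging Face cache. Here is a condensed sketch of that chain under assumed names (source, destination, cache_hit); it is not the function's exact control flow.

# Condensed sketch of the link fallback used when exposing a cached Hugging Face
# file inside a ComfyUI model directory. Names are illustrative.
import os
import shutil

def link_or_copy(source: str, destination: str, cache_hit: bool) -> bool:
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    try:
        os.symlink(source, destination)  # preferred: no duplicated bytes
        return True
    except OSError:
        try:
            os.link(source, destination)  # hard link: requires the same filesystem
            return True
        except OSError:
            if cache_hit:
                shutil.copyfile(source, destination)  # last resort: copy out of the cache
            return False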
@@ -558,7 +569,9 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
    HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
    HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf"),
    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q4_K_M.gguf"),
], folder_names=["diffusion_models", "unet"])

KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
@@ -630,7 +630,7 @@ class ModelPatcher(ModelManageable):
                # from gguf
                if is_quantized(weight):
                    out_weight = weight.to(device_to)
                    patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
                    patches = move_patch_to_device(self.patches[key], self.load_device if self.gguf.patch_on_device else self.offload_device)
                    # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
                    out_weight.patches = [(patches, key)]
                    if inplace_update:
@@ -674,7 +674,7 @@ class LoraLoader:
            lora = utils.load_torch_file(lora_path, safe_load=True)
            self.loaded_lora = (lora_path, lora)

        model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
        model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=lora_name)
        return (model_lora, clip_lora)

class LoraLoaderModelOnly(LoraLoader):
comfy/sd.py (14 changed lines)
@@ -15,7 +15,6 @@ import yaml
from . import clip_vision
from . import diffusers_convert
from . import gligen
from . import lora
from . import model_detection
from . import model_management
from . import model_patcher
@@ -37,6 +36,7 @@ from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.wan import vae as wan_vae
from .ldm.wan import vae2_2 as wan_vae2_2
from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip
from .lora_convert import convert_lora
from .model_management import load_models_gpu
from .model_patcher import ModelPatcher
@@ -64,15 +64,15 @@ from .utils import ProgressBar, FileMetadata
logger = logging.getLogger(__name__)


def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=None):
    key_map = {}
    if model is not None:
        key_map = lora.model_lora_keys_unet(model.model, key_map)
        key_map = model_lora_keys_unet(model.model, key_map)
    if clip is not None:
        key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)

    _lora = convert_lora(_lora)
    loaded = lora.load_lora(_lora, key_map)
    lora = convert_lora(lora)
    loaded = load_lora(lora, key_map, lora_name=lora_name)
    if model is not None:
        new_modelpatcher: ModelPatcher = model.clone()
        k = new_modelpatcher.add_patches(loaded, strength_model)
@@ -90,7 +90,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
        k1 = set(k1)
    for x in loaded:
        if (x not in k) and (x not in k1):
            logger.warning("NOT LOADED {}".format(x))
            logger.warning(f"[{lora_name}] clip keys not loaded {x}".format(x))

    return (new_modelpatcher, new_clip)
@@ -80,7 +80,6 @@ class TorchCompileModel(CustomNode):
            "backend": backend,
            "mode": mode,
        }
        move_to_gpu = True
        try:
            if backend == "torch_tensorrt":
                try:
@@ -98,7 +97,6 @@ class TorchCompileModel(CustomNode):
                    "enable_weight_streaming": True,
                    "make_refittable": True,
                }
                move_to_gpu = True
                del compile_kwargs["mode"]
            if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
                to_return = model.clone()
@@ -109,27 +107,18 @@ class TorchCompileModel(CustomNode):
                    object_patches = ["encoder", "decoder"]
                else:
                    patcher = to_return
                if object_patch is None or len(object_patches) == 0:
                if object_patch is None or len(object_patches) == 0 or len(object_patches) == 1 and object_patches[0].strip() == "":
                    object_patches = [DIFFUSION_MODEL]
                if move_to_gpu:
                    model_management.unload_all_models()
                    model_management.load_models_gpu([patcher])
                set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
                # m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
                # todo: do we want to move something back off the GPU?
                # if move_to_gpu:
                #     model_management.unload_all_models()
                return to_return,
            elif isinstance(model, torch.nn.Module):
                if move_to_gpu:
                    model_management.unload_all_models()
                    model.to(device=model_management.get_torch_device())
                model_management.unload_all_models()
                model.to(device=model_management.get_torch_device())
                res = torch.compile(model=model, **compile_kwargs),
                if move_to_gpu:
                    model.to(device=model_management.unet_offload_device())
                model.to(device=model_management.unet_offload_device())
                return res,
            else:
                logger.warning("Encountered a model that cannot be compiled")
                logger.warning(f"Encountered a model {model} that cannot be compiled")
                return model,
        except OSError as os_error:
            try:
@@ -174,7 +163,8 @@ class QuantizeModel(CustomNode):

    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
        model = model.clone()
        unet = model.get_model_object("diffusion_model")
        model.patch_model(force_patch_weights=True)
        unet = model.diffusion_model
        # todo: quantize quantizes in place, which is not desired

        # default exclusions
@@ -209,7 +199,7 @@ class QuantizeModel(CustomNode):
            if "autoquant" in strategy:
                _in_place_fixme = autoquant(unet, error_on_unseen=False)
            else:
                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
                _in_place_fixme = unet
            unwrap_tensor_subclass(_in_place_fixme)
        else:
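For reference, a minimal standalone example of the torchao quantize_ call pattern used in the new branch above (in-place quantization restricted by a filter_fn); the toy module and the skip rule are made up for illustration.

# Minimal torchao sketch mirroring the quantize_(..., filter_fn=...) call added above.
# The toy model and the filter rule are illustrative only.
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

def _quantize_filter(module: torch.nn.Module, fqn: str) -> bool:
    # quantize Linear layers only, and skip any layer whose qualified name contains "head"
    return isinstance(module, torch.nn.Linear) and "head" not in fqn

toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 10))
quantize_(toy, int8_dynamic_activation_int8_weight(), filter_fn=_quantize_filter)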
@@ -717,7 +717,7 @@ class LoraModelLoader:
        if strength_model == 0:
            return (model, )

        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0, None)
        return (model_lora, )