Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-10 05:22:34 +08:00.
Improvements to GGUF
- GGUF now works and includes 4-bit quants; this allows WAN to run on 24 GB VRAM GPUs.
- The logger only shows the full stack trace for errors.
- Helper functions for the Colab notebook.
- Fixed the nvcr.io auth error.
- LoRA issues are now reported more clearly.
- The model downloader uses the Hugging Face cache and symlinking when the platform supports it.
- The torch compile node now correctly patches the model before compilation.
This commit is contained in: parent 69a4906964, commit 83184916c1
.github/workflows/docker-build.yml (vendored): 5 changes
@@ -25,6 +25,11 @@ jobs:
           registry: ${{ env.REGISTRY }}
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}
+      - uses: docker/login-action@v3
+        with:
+          registry: nvcr.io
+          username: "$oauthtoken"
+          password: ${{ secrets.NVCR_NGC_TOKEN }}
       - name: Build and push CUDA (NVIDIA) image
         uses: docker/build-push-action@v6
         with:
comfy/app/colab.py (new file): 138 lines
import asyncio
import os
import stat
import subprocess
import threading
from asyncio import Task
from typing import NamedTuple

import requests

from ..cmd.folder_paths import init_default_paths, folder_names_and_paths
# experimental workarounds for colab
from ..cmd.main import _start_comfyui
from ..execution_context import *


class _ColabTuple(NamedTuple):
    tunnel: "CloudflaredTunnel"
    server: Task


_colab_instances: list[_ColabTuple] = []


class CloudflaredTunnel:
    """
    A class to manage a cloudflared tunnel subprocess.

    Provides methods to start, stop, and manage the lifecycle of the tunnel.
    It can be used as a context manager.
    """

    def __init__(self, port: int):
        self._port: int = port
        self._executable_path: str = "./cloudflared"
        self._process: Optional[subprocess.Popen] = None
        self._thread: Optional[threading.Thread] = None

        # Download and set permissions for the executable
        self._setup_executable()

        # Start the tunnel process and capture the URL
        self.url: str = self._start_tunnel()

    def _setup_executable(self):
        """Downloads cloudflared and makes it executable if it doesn't exist."""
        if not os.path.exists(self._executable_path):
            url = "https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64"
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(self._executable_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

            # Make the file executable (add execute permission for the owner)
            current_permissions = os.stat(self._executable_path).st_mode
            os.chmod(self._executable_path, current_permissions | stat.S_IEXEC)

    def _start_tunnel(self) -> str:
        """Starts the tunnel and returns the public URL."""
        command = [self._executable_path, "tunnel", "--url", f"http://localhost:{self._port}", "--no-autoupdate"]

        # Using DEVNULL for stderr to keep the output clean, stdout is piped
        self._process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            text=True
        )

        for line in iter(self._process.stdout.readline, ""):
            if ".trycloudflare.com" in line:
                # The line format is typically: "INFO | https://<subdomain>.trycloudflare.com |"
                try:
                    url = line.split("|")[1].strip()
                    print(f"Tunnel is live at: {url}")
                    return url
                except IndexError:
                    continue

        # If the loop finishes without finding a URL
        self.stop()
        raise RuntimeError("Failed to start cloudflared tunnel or find URL.")

    def stop(self):
        """Stops the cloudflared tunnel process."""
        if self._process and self._process.poll() is None:
            print("Stopping cloudflared tunnel...")
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
                print("Tunnel stopped successfully.")
            except subprocess.TimeoutExpired:
                print("Tunnel did not terminate gracefully, forcing kill.")
                self._process.kill()
            self._process = None

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit context manager, ensuring the tunnel is stopped."""
        self.stop()


def start_tunnel(port: int) -> CloudflaredTunnel:
    """
    Initializes and starts a cloudflared tunnel.

    Args:
        port: The local port number to expose.

    Returns:
        A CloudflaredTunnel object that controls the tunnel process.
        This object has a `url` attribute and a `stop()` method.
    """
    return CloudflaredTunnel(port)


def start_server_in_colab() -> str:
    """
    returns the URL of the tunnel and the running context
    :return:
    """
    if len(_colab_instances) == 0:
        comfyui_execution_context.set(ExecutionContext(server=ServerStub(), folder_names_and_paths=FolderNames(is_root=True), custom_nodes=ExportedNodes(), progress_registry=ProgressRegistryStub()))

        async def colab_server_loop():
            init_default_paths(folder_names_and_paths)
            await _start_comfyui()

        _loop = asyncio.get_running_loop()
        task = _loop.create_task(colab_server_loop())

        tunnel = start_tunnel(8188)
        _colab_instances.append(_ColabTuple(tunnel, task))
    return _colab_instances[0].tunnel.url
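For reference, a minimal usage sketch for a notebook cell, assuming the comfy package from this repository is installed in the Colab runtime and that the cell executes on the notebook's running event loop (which start_server_in_colab relies on via asyncio.get_running_loop):

from comfy.app.colab import start_server_in_colab

# Boots ComfyUI on port 8188 (as in the file above), opens a cloudflared tunnel,
# and returns the public *.trycloudflare.com URL; repeated calls reuse the first instance.
url = start_server_in_colab()
print(url)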
@@ -55,7 +55,7 @@ def on_flush(callback):
 
 
 class StackTraceLogger(logging.Logger):
     def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
-        if level >= logging.WARNING:
+        if level >= logging.ERROR:
             stack_info = True
         super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)
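A self-contained sketch of the behaviour change above (the logger name and messages are illustrative): with the threshold raised from WARNING to ERROR, warnings stay on one line and only errors carry the attached stack trace.

import logging


class StackTraceLogger(logging.Logger):
    # Mirrors the class in the diff: attach stack info only at ERROR and above.
    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
        if level >= logging.ERROR:
            stack_info = True
        super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel=stacklevel + 1)


logging.setLoggerClass(StackTraceLogger)
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("demo")
log.warning("stays a single line")
log.error("followed by 'Stack (most recent call last):' and the current stack")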
@@ -171,6 +171,16 @@ class Comfy:
     def task_count(self) -> int:
         return self._task_count
 
+    def __enter__(self):
+        self._is_running = True
+        return self
+
+    def __exit__(self, *args):
+        get_event_loop().run_in_executor(self._executor, _cleanup)
+        self._executor.shutdown(wait=True)
+        self._is_running = False
+
+
     async def __aenter__(self):
         self._is_running = True
         return self
@@ -19,15 +19,14 @@ from typing import List, Optional, Tuple, Literal
 import torch
 from opentelemetry.trace import get_current_span, StatusCode, Status
 
-# order matters
-from .main_pre import tracer
-
 from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
     DependencyAwareCache, \
     BasicCache
 from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy_execution.graph_utils import is_link, GraphBuilder
 from comfy_execution.utils import CurrentNodeContext
+# order matters
+from .main_pre import tracer
 from .. import interruption
 from .. import model_management
 from ..cli_args import args
@@ -38,7 +37,8 @@ from ..component_model.executor_types import ExecutorToClientProgress, Validatio
     HistoryResultDict, ExecutionErrorMessage, ExecutionInterruptedMessage
 from ..component_model.files import canonicalize_path
 from ..component_model.module_property import create_module_properties
-from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, ExecutionStatusAsDict
+from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus, \
+    ExecutionStatusAsDict
 from ..execution_context import context_execute_node, context_execute_prompt
 from ..execution_ext import should_panic_on_exception
 from ..nodes.package_typing import InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
@@ -48,6 +48,7 @@ from ..progress import get_progress_state, reset_progress_state, add_progress_ha
 from ..validation import validate_node_input
 
 _module_properties = create_module_properties()
+logger = logging.getLogger(__name__)
 
 
 @_module_properties.getter
@@ -100,12 +101,12 @@ class CacheSet:
     def __init__(self, cache_type=None, cache_size=None):
         if cache_type == CacheType.DEPENDENCY_AWARE:
             self.init_dependency_aware_cache()
-            logging.info("Disabling intermediate node cache.")
+            logger.info("Disabling intermediate node cache.")
         elif cache_type == CacheType.LRU:
             if cache_size is None:
                 cache_size = 0
             self.init_lru_cache(cache_size)
-            logging.info("Using LRU cache")
+            logger.info("Using LRU cache")
         else:
             self.init_classic_cache()
 
@@ -571,7 +572,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
             return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
         caches.outputs.set(unique_id, output_data)
     except interruption.InterruptProcessingException as iex:
-        logging.info("Processing interrupted")
+        logger.info("Processing interrupted")
 
         # skip formatting inputs/outputs
         error_details: RecursiveExecutionErrorDetailsInterrupted = {
@@ -588,13 +589,13 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
             for name, inputs in input_data_all.items():
                 input_data_formatted[name] = [format_value(x) for x in inputs]
 
-        logging.error("An error occurred while executing a workflow", exc_info=ex)
-        logging.error(traceback.format_exc())
+        logger.error("An error occurred while executing a workflow", exc_info=ex)
+        logger.error(traceback.format_exc())
         tips = ""
 
         if isinstance(ex, model_management.OOM_EXCEPTION):
             tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
-            logging.error("Got an OOM, unloading all loaded models.")
+            logger.error("Got an OOM, unloading all loaded models.")
             model_management.unload_all_models()
 
         error_details: RecursiveExecutionErrorDetails = {
@@ -606,7 +607,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
         }
 
         if should_panic_on_exception(ex, args.panic_when):
-            logging.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
+            logger.error(f"The exception {ex} was configured as unrecoverable, scheduling an exit")
 
             def sys_exit(*args):
                 sys.exit(1)
@@ -1120,11 +1121,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
         if valid is True:
             good_outputs.add(o)
         else:
-            logging.error(f"Failed to validate prompt for output {o}:")
+            logger.error(f"Failed to validate prompt for output {o}:")
             if len(reasons) > 0:
-                logging.error("* (prompt):")
+                logger.error("* (prompt):")
                 for reason in reasons:
-                    logging.error(f" - {reason['message']}: {reason['details']}")
+                    logger.error(f" - {reason['message']}: {reason['details']}")
             errors += [(o, reasons)]
             for node_id, result in validated.items():
                 valid = result[0]
@@ -1140,11 +1141,11 @@ async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, ty
                         "dependent_outputs": [],
                         "class_type": class_type
                     }
-                    logging.error(f"* {class_type} {node_id}:")
+                    logger.error(f"* {class_type} {node_id}:")
                     for reason in reasons:
-                        logging.error(f" - {reason['message']}: {reason['details']}")
+                        logger.error(f" - {reason['message']}: {reason['details']}")
                 node_errors[node_id]["dependent_outputs"].append(o)
-            logging.error("Output will be ignored")
+            logger.error("Output will be ignored")
 
     if len(good_outputs) == 0:
         errors_list = []
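The recurring edit in the hunks above replaces calls on the root logging module with a module-level logger created next to _module_properties; a minimal sketch of that pattern (the function and message are illustrative):

import logging

# One logger per module: records carry the module's __name__ and verbosity can
# be tuned per module instead of globally on the root logger.
logger = logging.getLogger(__name__)


def run_step():
    try:
        raise RuntimeError("boom")
    except RuntimeError as ex:
        # Same shape as the logging.error(...) -> logger.error(...) changes above.
        logger.error("An error occurred while executing a workflow", exc_info=ex)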
@@ -93,7 +93,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
         ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(["configs"], additional_absolute_directory_paths=[get_package_as_path("comfy.configs")], supported_extensions={".yaml"}),
         ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
-        ModelPaths(folder_names=["clip", "text_encoders"], supported_extensions=set(supported_pt_extensions)),
+        ModelPaths(folder_names=["text_encoders", "clip"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
         ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
         ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),
@@ -939,16 +939,16 @@ def get_torch_compiler_disable_decorator():
     from packaging import version
 
     if not chained_hasattr(torch, "compiler.disable"):
-        logger.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
+        logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
         return dummy_decorator  # torch too old
     elif version.parse(torch.__version__) >= version.parse("2.8"):
-        logger.info("ComfyUI-GGUF: Allowing full torch compile")
+        logger.debug("ComfyUI-GGUF: Allowing full torch compile")
        return dummy_decorator  # torch compile works
     if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
-        logger.info("ComfyUI-GGUF: Allowing full torch compile (nightly)")
+        logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
         return dummy_decorator  # torch compile works, nightly before 2.8 release
     else:
-        logger.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
+        logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
         return torch.compiler.disable
 
 
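A standalone sketch of the same gating idea, with simplified stand-ins for the repository's chained_hasattr and dummy_decorator helpers (their real implementations are not shown in this diff): choose a no-op decorator when torch.compile can handle the GGUF tensor subclasses, otherwise fall back to torch.compiler.disable.

import functools

import torch
from packaging import version


def chained_hasattr(obj, dotted_name: str) -> bool:
    # Simplified stand-in: True if obj.a.b.c exists for dotted_name "a.b.c".
    try:
        functools.reduce(getattr, dotted_name.split("."), obj)
        return True
    except AttributeError:
        return False


def dummy_decorator(fn=None, **_kwargs):
    # Simplified stand-in: a no-op decorator that leaves the callable traceable.
    return fn if fn is not None else (lambda f: f)


def pick_compiler_disable_decorator():
    if not chained_hasattr(torch, "compiler.disable"):
        return dummy_decorator  # torch too old to have torch.compiler.disable
    if version.parse(torch.__version__) >= version.parse("2.8"):
        return dummy_decorator  # full torch compile works
    if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
        return dummy_decorator  # nightly before the 2.8 release
    return torch.compiler.disable  # partial torch compile only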
@@ -24,8 +24,10 @@ import torch
 from . import model_base
 from . import model_management
 from . import utils
-from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
 from . import weight_adapter
+from .lora_types import PatchDict, PatchOffset, PatchConversionFunction, PatchType, ModelPatchesDictValue
+
+logger = logging.getLogger(__name__)
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -37,7 +39,7 @@ LORA_CLIP_MAP = {
 }
 
 
-def load_lora(lora, to_load, log_missing=True) -> PatchDict:
+def load_lora(lora, to_load, log_missing=True, lora_name=None) -> PatchDict:
     patch_dict: PatchDict = {}
     loaded_keys = set()
     for x in to_load:
@@ -91,9 +93,10 @@ def load_lora(lora, to_load, log_missing=True) -> PatchDict:
             loaded_keys.add(set_weight_name)
 
     if log_missing:
-        for x in lora.keys():
-            if x not in loaded_keys:
-                logging.warning("lora key not loaded: {}".format(x))
+        not_loaded_keys = [x for x in lora.keys() if x not in loaded_keys]
+        n_not_loaded_keys = len(not_loaded_keys)
+        if n_not_loaded_keys > 0:
+            logger.warning(f"[{lora_name}] lora keys not loaded ({n_not_loaded_keys} / {len(loaded_keys) + n_not_loaded_keys}): {not_loaded_keys}")
 
     return patch_dict
 
@@ -293,12 +296,12 @@ def model_lora_keys_unet(model, key_map=None):
         if k.startswith("diffusion_model."):
             if k.endswith(".weight"):
                 key_lora = k[len("diffusion_model."):-len(".weight")]
-                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
-                key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
+                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k  # SimpleTuner lycoris format
+                key_map["transformer.{}".format(key_lora)] = k  # SimpleTuner regular format
 
     if isinstance(model, model_base.ACEStep):
         for k in sdk:
-            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
+            if k.startswith("diffusion_model.") and k.endswith(".weight"):  # Official ACE step lora format
                 key_lora = k[len("diffusion_model."):-len(".weight")]
                 key_map["{}".format(key_lora)] = k
 
@@ -364,7 +367,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
         if isinstance(v, weight_adapter.WeightAdapterBase):
             output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
             if output is None:
-                logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
+                logger.warning("Calculate Weight Failed: {} {}".format(v.name, key))
             else:
                 weight = output
                 if old_weight is not None:
@@ -382,12 +385,12 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
             # An extra flag to pad the weight if the diff's shape is larger than the weight
             do_pad_weight = len(v) > 1 and v[1]['pad_weight']
             if do_pad_weight and diff.shape != weight.shape:
-                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
+                logger.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
                 weight = pad_tensor_to_shape(weight, diff.shape)
 
             if strength != 0.0:
                 if diff.shape != weight.shape:
-                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
+                    logger.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                 else:
                     weight += function(strength * model_management.cast_to_device(diff, weight.device, weight.dtype))
         elif patch_type == "set":
@@ -398,7 +401,7 @@ def calculate_weight(patches: ModelPatchesDictValue, weight, key, intermediate_d
                 model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
             weight += function(strength * model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
         else:
-            logging.warning("patch type not recognized {} {}".format(patch_type, key))
+            logger.warning("patch type not recognized {} {}".format(patch_type, key))
 
         if old_weight is not None:
             weight = old_weight
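The load_lora change above is what the "LoRA issues are now reported more clearly" bullet refers to: instead of one warning per unmatched key, a single summary names the LoRA and counts the misses. A standalone sketch of that reporting pattern (names and keys are illustrative):

import logging

logger = logging.getLogger(__name__)


def report_unmatched_keys(lora_name, lora_keys, loaded_keys):
    # One aggregated warning per LoRA instead of one line per missing key.
    not_loaded = [k for k in lora_keys if k not in loaded_keys]
    if not_loaded:
        total = len(loaded_keys) + len(not_loaded)
        logger.warning(f"[{lora_name}] lora keys not loaded ({len(not_loaded)} / {total}): {not_loaded}")


report_unmatched_keys("my_style.safetensors",
                      lora_keys=["lora_unet_a.alpha", "lora_unet_b.alpha"],
                      loaded_keys={"lora_unet_a.alpha"})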
@@ -1,8 +1,5 @@
 from __future__ import annotations
 
-from itertools import chain
-from os.path import join
-
 import collections
 import logging
 import operator
@@ -10,6 +7,8 @@ import os
 import shutil
 from collections.abc import Sequence, MutableSequence
 from functools import reduce
+from itertools import chain
+from os.path import join
 from pathlib import Path
 from typing import List, Optional, Final, Set
 
@@ -66,28 +65,38 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
     filename = canonicalize_path(filename)
     path = folder_paths.get_full_path(folder_name, filename)
 
+    candidate_str_match = False
+    candidate_filename_match = False
+    candidate_alternate_filenames_match = False
+    candidate_save_filename_match = False
     if path is None and not args.disable_known_models:
         try:
            # todo: should this be the first or last path?
            this_model_directory = folder_paths.get_folder_paths(folder_name)[0]
            known_file: Optional[HuggingFile | CivitFile] = None
            for candidate in known_files:
-                if (canonicalize_path(str(candidate)) == filename
-                        or canonicalize_path(candidate.filename) == filename
-                        or filename in list(map(canonicalize_path, candidate.alternate_filenames))
-                        or filename == canonicalize_path(candidate.save_with_filename)):
+                candidate_str_match = canonicalize_path(str(candidate)) == filename
+                candidate_filename_match = canonicalize_path(candidate.filename) == filename
+                candidate_alternate_filenames_match = filename in list(map(canonicalize_path, candidate.alternate_filenames))
+                candidate_save_filename_match = filename == canonicalize_path(candidate.save_with_filename)
+                if (candidate_str_match
+                        or candidate_filename_match
+                        or candidate_alternate_filenames_match
+                        or candidate_save_filename_match):
                    known_file = candidate
                    break
            if known_file is None:
+                logger.debug(f"get_or_download could not find {filename} in {folder_name}, known_files={known_files}")
                return path
            with comfy_tqdm():
                if isinstance(known_file, HuggingFile):
+                    symlinks_supported = are_symlinks_supported()
                    if known_file.save_with_filename is not None:
                        linked_filename = known_file.save_with_filename
                    elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
                        linked_filename = os.path.basename(known_file.filename)
                    else:
-                        linked_filename = None
+                        linked_filename = known_file.filename
 
                    if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
                        # if the known file has an overridden linked name, save it into a repo_id sub directory
@@ -112,8 +121,6 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
 
                    cache_hit = False
                    try:
-                        if not are_symlinks_supported():
-                            raise PermissionError("no symlink support")
                        # always retrieve this from the cache if it already exists there
                        path = hf_hub_download(repo_id=known_file.repo_id,
                                               filename=known_file.filename,
@@ -121,19 +128,20 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                                               revision=known_file.revision,
                                               local_files_only=True,
                                               )
-                        logger.info(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
-                        if linked_filename is None:
-                            linked_filename = known_file.filename
+                        logger.debug(f"hf_hub_download cache hit for {known_file.repo_id}/{known_file.filename}")
                        cache_hit = True
-                    except (LocalEntryNotFoundError, PermissionError):
-                        path = hf_hub_download(repo_id=known_file.repo_id,
-                                               filename=known_file.filename,
-                                               local_dir=hf_destination_dir,
-                                               repo_type=known_file.repo_type,
-                                               revision=known_file.revision,
-                                               )
+                    except LocalEntryNotFoundError:
+                        try:
+                            logger.debug(f"{folder_name}/{filename} is being downloaded from {known_file.repo_id}/{known_file.filename} candidate_str_match={candidate_str_match} candidate_filename_match={candidate_filename_match} candidate_alternate_filenames_match={candidate_alternate_filenames_match} candidate_save_filename_match={candidate_save_filename_match}")
+                            path = hf_hub_download(repo_id=known_file.repo_id,
+                                                   filename=known_file.filename,
+                                                   repo_type=known_file.repo_type,
+                                                   revision=known_file.revision,
+                                                   )
+                        except IOError as exc_info:
+                            logger.error(f"cannot reach huggingface {known_file.repo_id}/{known_file.filename}", exc_info=exc_info)
 
-                    if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
+                    if path is not None and known_file.convert_to_16_bit and file_size is not None and file_size != 0:
                        tensors = {}
                        with safe_open(path, framework="pt") as f:
                            with tqdm.tqdm(total=len(f.keys())) as pb:
@@ -151,20 +159,23 @@ def get_or_download(folder_name: str, filename: str, known_files: Optional[List[
                        logger.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
 
                    link_successful = True
-                    if linked_filename is not None:
+                    if path is not None:
                        destination_link = os.path.join(this_model_directory, linked_filename)
-                        try:
-                            os.makedirs(this_model_directory, exist_ok=True)
-                            os.symlink(path, destination_link)
-                        except Exception as exc_info:
-                            logger.error("error while symbolic linking", exc_info=exc_info)
-                            try:
-                                os.link(path, destination_link)
-                            except Exception as hard_link_exc:
-                                logger.error("error while hard linking", exc_info=hard_link_exc)
-                                if cache_hit:
-                                    shutil.copyfile(path, destination_link)
-                                link_successful = False
+                        if Path(destination_link).is_file():
+                            logger.warning(f"{known_file.repo_id}/{known_file.filename} could not link to {destination_link} because the path already exists, which is unexpected")
+                        else:
+                            try:
+                                os.makedirs(this_model_directory, exist_ok=True)
+                                os.symlink(path, destination_link)
+                            except Exception as exc_info:
+                                logger.error("error while symbolic linking", exc_info=exc_info)
+                                try:
+                                    os.link(path, destination_link)
+                                except Exception as hard_link_exc:
+                                    logger.error("error while hard linking", exc_info=hard_link_exc)
+                                    if cache_hit:
+                                        shutil.copyfile(path, destination_link)
+                                    link_successful = False
 
                    if not link_successful:
                        logger.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}. If cache_hit={cache_hit} is True, the file was copied into the destination.", exc_info=exc_info)
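The linking code above is the "Hugging Face cache and symlinking" behaviour from the commit message: the downloaded file stays in the Hugging Face cache and a symlink is placed in the models directory, with a hard link and finally a copy as fallbacks. A condensed sketch of that fallback chain (paths are illustrative):

import logging
import os
import shutil

logger = logging.getLogger(__name__)


def link_into_models_dir(cached_path: str, destination_link: str) -> None:
    # Prefer a symlink (single copy kept in the Hugging Face cache), then a
    # hard link, then a plain copy as a last resort.
    os.makedirs(os.path.dirname(destination_link), exist_ok=True)
    try:
        os.symlink(cached_path, destination_link)
    except OSError as exc_info:
        logger.error("error while symbolic linking", exc_info=exc_info)
        try:
            os.link(cached_path, destination_link)
        except OSError as hard_link_exc:
            logger.error("error while hard linking", exc_info=hard_link_exc)
            shutil.copyfile(cached_path, destination_link)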
@@ -558,7 +569,9 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
     HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf"),
     HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q4_K_M.gguf"),
 ], folder_names=["diffusion_models", "unet"])
 
 KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
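The new Q4_K_M entries are what lets the WAN 2.2 A14B models fit on 24 GB GPUs per the commit message; the repo and file names come from the list above. The same files can be fetched directly with huggingface_hub, which is also what the downloader uses under the hood:

from huggingface_hub import hf_hub_download

# Downloads into the shared Hugging Face cache; ComfyUI then links the cached
# file into the diffusion_models folder (see the linking logic earlier).
path = hf_hub_download(
    repo_id="QuantStack/Wan2.2-T2V-A14B-GGUF",
    filename="HighNoise/Wan2.2-T2V-A14B-HighNoise-Q4_K_M.gguf",
)
print(path)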
@@ -630,7 +630,7 @@ class ModelPatcher(ModelManageable):
                 # from gguf
                 if is_quantized(weight):
                     out_weight = weight.to(device_to)
-                    patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
+                    patches = move_patch_to_device(self.patches[key], self.load_device if self.gguf.patch_on_device else self.offload_device)
                     # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
                     out_weight.patches = [(patches, key)]
                     if inplace_update:
@@ -674,7 +674,7 @@ class LoraLoader:
             lora = utils.load_torch_file(lora_path, safe_load=True)
             self.loaded_lora = (lora_path, lora)
 
-        model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
+        model_lora, clip_lora = sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=lora_name)
         return (model_lora, clip_lora)
 
 class LoraLoaderModelOnly(LoraLoader):
comfy/sd.py: 14 changes
@@ -15,7 +15,6 @@ import yaml
 from . import clip_vision
 from . import diffusers_convert
 from . import gligen
-from . import lora
 from . import model_detection
 from . import model_management
 from . import model_patcher
@@ -37,6 +36,7 @@ from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.wan import vae as wan_vae
 from .ldm.wan import vae2_2 as wan_vae2_2
+from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip
 from .lora_convert import convert_lora
 from .model_management import load_models_gpu
 from .model_patcher import ModelPatcher
@@ -64,15 +64,15 @@ from .utils import ProgressBar, FileMetadata
 logger = logging.getLogger(__name__)
 
 
-def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_name=None):
     key_map = {}
     if model is not None:
-        key_map = lora.model_lora_keys_unet(model.model, key_map)
+        key_map = model_lora_keys_unet(model.model, key_map)
     if clip is not None:
-        key_map = lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+        key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
 
-    _lora = convert_lora(_lora)
-    loaded = lora.load_lora(_lora, key_map)
+    lora = convert_lora(lora)
+    loaded = load_lora(lora, key_map, lora_name=lora_name)
     if model is not None:
         new_modelpatcher: ModelPatcher = model.clone()
         k = new_modelpatcher.add_patches(loaded, strength_model)
@@ -90,7 +90,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
     k1 = set(k1)
     for x in loaded:
         if (x not in k) and (x not in k1):
-            logger.warning("NOT LOADED {}".format(x))
+            logger.warning(f"[{lora_name}] clip keys not loaded {x}".format(x))
 
     return (new_modelpatcher, new_clip)
 
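A hedged usage sketch of the new lora_name parameter; the file paths and the checkpoint-loading call are illustrative assumptions, not part of this diff:

from comfy import sd, utils

# Illustrative: load_checkpoint_guess_config is assumed to return the usual
# (model, clip, vae, clip_vision) tuple.
model, clip, vae, _ = sd.load_checkpoint_guess_config("checkpoints/model.safetensors")
lora_sd = utils.load_torch_file("loras/my_style.safetensors", safe_load=True)

# lora_name only affects reporting: unmatched keys are logged as
# "[my_style.safetensors] lora keys not loaded (...)".
model_lora, clip_lora = sd.load_lora_for_models(
    model, clip, lora_sd, strength_model=1.0, strength_clip=1.0,
    lora_name="my_style.safetensors",
)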
@@ -80,7 +80,6 @@ class TorchCompileModel(CustomNode):
             "backend": backend,
             "mode": mode,
         }
-        move_to_gpu = True
         try:
             if backend == "torch_tensorrt":
                 try:
@@ -98,7 +97,6 @@ class TorchCompileModel(CustomNode):
                     "enable_weight_streaming": True,
                     "make_refittable": True,
                 }
-                move_to_gpu = True
                 del compile_kwargs["mode"]
             if isinstance(model, (ModelPatcher, TransformersManagedModel, VAE)):
                 to_return = model.clone()
@@ -109,27 +107,18 @@ class TorchCompileModel(CustomNode):
                     object_patches = ["encoder", "decoder"]
                 else:
                     patcher = to_return
-                if object_patch is None or len(object_patches) == 0:
+                if object_patch is None or len(object_patches) == 0 or len(object_patches) == 1 and object_patches[0].strip() == "":
                     object_patches = [DIFFUSION_MODEL]
-                if move_to_gpu:
-                    model_management.unload_all_models()
-                    model_management.load_models_gpu([patcher])
                 set_torch_compile_wrapper(patcher, keys=object_patches, **compile_kwargs)
-                # m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
-                # todo: do we want to move something back off the GPU?
-                # if move_to_gpu:
-                # model_management.unload_all_models()
                 return to_return,
             elif isinstance(model, torch.nn.Module):
-                if move_to_gpu:
-                    model_management.unload_all_models()
-                    model.to(device=model_management.get_torch_device())
+                model_management.unload_all_models()
+                model.to(device=model_management.get_torch_device())
                 res = torch.compile(model=model, **compile_kwargs),
-                if move_to_gpu:
-                    model.to(device=model_management.unet_offload_device())
+                model.to(device=model_management.unet_offload_device())
                 return res,
             else:
-                logger.warning("Encountered a model that cannot be compiled")
+                logger.warning(f"Encountered a model {model} that cannot be compiled")
                 return model,
         except OSError as os_error:
             try:
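For the plain nn.Module branch above, the node now always unloads other models, moves the module to the compute device, compiles, and moves it back to the offload device. A minimal standalone sketch of that compile-then-offload flow (the toy module, devices, and backend choice are illustrative):

import torch

module = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

compile_device = "cuda" if torch.cuda.is_available() else "cpu"
offload_device = "cpu"

module.to(device=compile_device)
compiled = torch.compile(model=module, backend="inductor", mode="default")
module.to(device=offload_device)  # the compiled wrapper shares the module's parameters

# Compilation is lazy: the first forward pass triggers tracing and code generation.
out = compiled(torch.randn(2, 8, device=offload_device))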
@@ -174,7 +163,8 @@ class QuantizeModel(CustomNode):
 
     def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
-        unet = model.get_model_object("diffusion_model")
+        model.patch_model(force_patch_weights=True)
+        unet = model.diffusion_model
         # todo: quantize quantizes in place, which is not desired
 
         # default exclusions
@@ -209,7 +199,7 @@ class QuantizeModel(CustomNode):
             if "autoquant" in strategy:
                 _in_place_fixme = autoquant(unet, error_on_unseen=False)
             else:
-                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
+                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
                 _in_place_fixme = unet
             unwrap_tensor_subclass(_in_place_fixme)
         else:
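A hedged sketch of the quantization call above with an explicit filter_fn, assuming a recent torchao; the toy module and the exclusion rule are illustrative. Note that the node now calls model.patch_model(force_patch_weights=True) first so LoRA patches are baked in before quantization.

import torch
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

# Toy stand-in for the diffusion model.
unet = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.Linear(64, 16),
)


def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
    # Quantize only Linear layers and skip explicitly excluded module names.
    excluded = {"1"}  # illustrative: skip the final projection ("1" in this Sequential)
    return isinstance(module, torch.nn.Linear) and fqn not in excluded


quantize_(unet, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)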
@@ -717,7 +717,7 @@ class LoraModelLoader:
         if strength_model == 0:
             return (model, )
 
-        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
+        model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0, None)
         return (model_lora, )
 
 