mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)
Experimental GGUF support
commit 69a4906964
parent fe4057385d
@@ -278,7 +278,7 @@ class PromptServer(ExecutorToClientProgress):
         sid,
     )

-    logger.info(
+    logger.debug(
         f"Feature flags negotiated for client {sid}: {client_flags}"
     )
     first_message = False
@@ -13,7 +13,7 @@ from typing import Any, NamedTuple, Optional, Iterable

 from .platform_path import construct_path

-supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json"])
+supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json", ".gguf"])
 extension_mimetypes_cache = {
     "webp": "image",
     "fbx": "model",
comfy/gguf.py: new file, 1203 lines (file diff suppressed because it is too large)
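Since the 1,203-line comfy/gguf.py diff is suppressed, the following is only a hedged sketch of how GGUF tensors can be enumerated with the `gguf` package that this commit adds as a dependency. It is not the contents of comfy/gguf.py; the helpers referenced elsewhere in the diff (gguf_sd_loader, GGMLOps, move_patch_to_device, is_quantized, is_torch_compatible) are assumed to live in that file.

```python
# Illustrative only: list the tensors stored in a GGUF checkpoint with the `gguf` package.
# This is not comfy/gguf.py; the function name and return shape are assumptions.
import gguf


def list_gguf_tensors(path: str) -> dict[str, tuple]:
    reader = gguf.GGUFReader(path)
    out = {}
    for tensor in reader.tensors:
        # tensor.name becomes the state-dict key; tensor_type records the
        # quantization (e.g. Q8_0), and shape is the stored layout.
        out[tensor.name] = (tuple(int(d) for d in tensor.shape), tensor.tensor_type)
    return out
```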
@ -1,11 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from . import supported_models, utils
|
||||
from . import supported_models_base
|
||||
from .gguf import GGMLOps
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
@@ -620,7 +622,7 @@ def model_config_from_unet_config(unet_config, state_dict=None):
     return None


-def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata:Optional[dict]=None):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
     if unet_config is None:
         return None
@@ -639,6 +641,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     else:
         model_config.optimizations["fp8"] = True

+    if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
+        model_config.custom_operations = GGMLOps
+
     return model_config


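For context, the new gate above is what routes GGUF checkpoints onto the GGML-aware layer implementations. A hedged usage sketch tying it to the loader changes later in this commit (module paths and the filename are illustrative):

```python
# Hypothetical usage; assumes this fork's comfy.utils / comfy.model_detection modules.
from comfy import model_detection, utils
from comfy.gguf import GGMLOps

state_dict, metadata = utils.load_torch_file("model-Q8_0.gguf", return_metadata=True)
# For .gguf files, metadata is {"format": "gguf"}, so the detected config swaps in GGMLOps:
config = model_detection.model_config_from_unet(state_dict, "", metadata=metadata)
assert config is None or config.custom_operations is GGMLOps
```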
@@ -557,6 +557,8 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
 ], folder_names=["diffusion_models", "unet"])

 KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
@@ -19,6 +19,7 @@ from __future__ import annotations

 import collections
 import copy
+import dataclasses
 import inspect
 import logging
 import typing
@@ -37,6 +38,7 @@ from . import utils
 from .comfy_types import UnetWrapperFunction
 from .component_model.deprecation import _deprecate_method
 from .float import stochastic_rounding
+from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
 from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
 from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel
@@ -221,6 +223,13 @@ class MemoryCounter:
         self.value -= used


+@dataclasses.dataclass
+class GGUFQuantization:
+    loaded_from_gguf: bool = False
+    mmap_released: bool = False
+    patch_on_device: bool = False
+
+
 class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
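A quick reading of the three flags, based only on how the hunks below use them (the last flag is not read through self.gguf in the hunks shown):

```python
# Illustrative defaults; see the surrounding hunks for where each flag is touched.
state = GGUFQuantization()
state.loaded_from_gguf  # set in ModelPatcher.__init__ when the model was built with GGMLOps
state.mmap_released     # reset to False on clone, set True by load() after weights leave the mmap
state.patch_on_device   # declared here; patch_weight_to_device below checks self.patch_on_device instead
```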
@@ -257,6 +266,9 @@ class ModelPatcher(ModelManageable):
         self.forced_hooks: Optional[HookGroup] = None # NOTE: only used for CLIP at this time
         self.is_clip = False
         self.hook_mode = EnumHookMode.MaxSpeed
+        self.gguf = GGUFQuantization()
+        if isinstance(model, BaseModel) and model.operations == GGMLOps:
+            self.gguf.loaded_from_gguf = True

     @property
     def model_options(self) -> ModelOptions:
@@ -361,6 +373,9 @@ class ModelPatcher(ModelManageable):
         n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
         n.is_clip = self.is_clip
         n.hook_mode = self.hook_mode
+        n.gguf = copy.copy(self.gguf)
+        # todo: when is this set back to False? when would it make sense to?
+        n.gguf.mmap_released = False

         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
             callback(self, n)
@@ -612,6 +627,19 @@ class ModelPatcher(ModelManageable):
         weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update

+        # from gguf
+        if is_quantized(weight):
+            out_weight = weight.to(device_to)
+            patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
+            # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
+            out_weight.patches = [(patches, key)]
+            if inplace_update:
+                utils.copy_to_param(self.model, key, out_weight)
+            else:
+                utils.set_attr_param(self.model, key, out_weight)
+            return
+        # end gguf
+
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
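The hunk above only records patches on the quantized tensor instead of baking them in. Since comfy/gguf.py is suppressed here, the following is a hedged sketch (helper names are assumptions, not the project's actual code) of the usual way such deferred patches are resolved when the layer runs: dequantize first, then apply the recorded patches to the dense weight.

```python
# Sketch under assumptions: `dequantize` and `apply_patches` stand in for helpers that
# would live in comfy/gguf.py; only the ordering is the point being illustrated.
import torch


def resolve_weight(weight: torch.Tensor, dtype: torch.dtype, dequantize, apply_patches) -> torch.Tensor:
    out = dequantize(weight, dtype)                      # quantized GGUF data -> dense tensor
    for patches, key in getattr(weight, "patches", []):  # patches recorded by patch_weight_to_device
        out = apply_patches(patches, out, key)           # e.g. LoRA deltas applied after dequantization
    return out
```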
@@ -648,6 +676,9 @@ class ModelPatcher(ModelManageable):
         return loading

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        if self.gguf.loaded_from_gguf:
+            force_patch_weights = True
+
         with self.use_ejected():
             self.unpatch_hooks()
             mem_counter = 0
@@ -744,6 +775,28 @@ class ModelPatcher(ModelManageable):
                 self.model.to(device_to)
                 mem_counter = self.model_size()

+            if self.gguf.loaded_from_gguf and not self.gguf.mmap_released:
+                # todo: when is mmap_released set to True?
+                linked = []
+                if lowvram_model_memory > 0:
+                    for n, m in self.model.named_modules():
+                        if hasattr(m, "weight"):
+                            device = getattr(m.weight, "device", None)
+                            if device == self.offload_device:
+                                linked.append((n, m))
+                                continue
+                        if hasattr(m, "bias"):
+                            device = getattr(m.bias, "device", None)
+                            if device == self.offload_device:
+                                linked.append((n, m))
+                                continue
+                if linked and self.load_device != self.offload_device:
+                    logger.info(f"gguf attempting to release mmap ({len(linked)})")
+                    for n, m in linked:
+                        # TODO: possible to OOM, find better way to detach
+                        m.to(self.load_device).to(self.offload_device)
+                    self.gguf.mmap_released = True
+
             self._memory_measurements.lowvram_patch_counter += patch_counter

             self.model_device = device_to
@@ -774,6 +827,13 @@ class ModelPatcher(ModelManageable):

     def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
+        if self.gguf.loaded_from_gguf and unpatch_weights:
+            for p in self.model.parameters():
+                if is_torch_compatible(p):
+                    continue
+                patches = self.patches
+                if len(patches) > 0:
+                    p.patches = []
         if unpatch_weights:
             self.unpatch_hooks()
             if self._memory_measurements.model_lowvram:
@@ -503,7 +503,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
         weight_dtype == torch.float16 and
         (compute_dtype == torch.float16 or compute_dtype is None)
     ):
-        logging.info("Using cublas ops")
+        logger.info("Using cublas ops")
         return cublas_ops

     if compute_dtype is None or weight_dtype == compute_dtype:
comfy/sd.py: 20
@@ -23,6 +23,7 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
+from .component_model.deprecation import _deprecate_method
 from .hooks import EnumHookMode
 from .ldm.ace.vae.music_dcae_pipeline import MusicDCAE
 from .ldm.audio.autoencoder import AudioOobleckVAE
@@ -58,7 +59,7 @@ from .text_encoders import sa_t5
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
 from .text_encoders import wan
-from .utils import ProgressBar
+from .utils import ProgressBar, FileMetadata

 logger = logging.getLogger(__name__)

@@ -1098,7 +1099,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return out


-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[str | dict] = None, ckpt_path=""):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[FileMetadata] = None, ckpt_path=""):
     if te_model_options is None:
         te_model_options = {}
     if model_options is None:
@@ -1182,7 +1183,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (_model_patcher, clip, vae, clipvision)


-def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = ""): # load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = "", metadata: Optional[FileMetadata] = None): # load unet in diffusers or regular format
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.

@@ -1192,6 +1193,7 @@ load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
         - dtype: Override model data type
         - custom_operations: Custom model operations
         - fp8_optimizations: Enable FP8 optimizations
+        metadata: file metadata

     Returns:
         ModelPatcher: A wrapped model instance that handles device management and weight loading.
@@ -1217,14 +1219,14 @@ load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
     parameters = utils.calculate_parameters(sd)
     weight_dtype = utils.weight_dtype(sd)
     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)

     if model_config is not None:
         new_sd = sd
     else:
         new_sd = model_detection.convert_diffusers_mmdit(sd, "")
         if new_sd is not None: # diffusers mmdit
-            model_config = model_detection.model_config_from_unet(new_sd, "")
+            model_config = model_detection.model_config_from_unet(new_sd, "", metadata=metadata)
             if model_config is None:
                 return None
         else: # diffusers unet
@@ -1269,21 +1271,21 @@ load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
 def load_diffusion_model(unet_path, model_options: dict = None):
     if model_options is None:
         model_options = {}
-    sd = utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path)
+    sd, metadata = utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path, metadata=metadata)
     if model is None:
         logger.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
     return model


+@_deprecate_method(message="The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model", version="*")
 def load_unet(unet_path, dtype=None):
-    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
     return load_diffusion_model(unet_path, model_options={"dtype": dtype})


+@_deprecate_method(message="The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict", version="*")
 def load_unet_state_dict(sd, dtype=None):
-    logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
     return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})


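Taken together with the loader changes below, the practical effect is that a GGUF UNet goes through the same entry point as a safetensors checkpoint. A small hedged usage sketch (the module path and filename are illustrative):

```python
# Hypothetical call; assumes a .gguf UNet sitting in a known diffusion_models folder.
from comfy.sd import load_diffusion_model

patcher = load_diffusion_model("models/diffusion_models/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf")
# load_torch_file tags the state dict with {"format": "gguf"}, model detection then picks
# GGMLOps, and the returned ModelPatcher starts with gguf.loaded_from_gguf == True.
```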
@@ -31,7 +31,7 @@ import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from pickle import UnpicklingError
-from typing import Optional, Any
+from typing import Optional, Any, Literal

 import numpy as np
 import safetensors.torch
@@ -40,15 +40,15 @@ from PIL import Image
 from einops import rearrange
 from torch.nn.functional import interpolate
 from tqdm import tqdm
+from typing_extensions import TypedDict, NotRequired

 from comfy_api import feature_flags
 from . import interruption, checkpoint_pickle
 from .cli_args import args
 from .component_model import files
 from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
 from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
+from .gguf import gguf_sd_loader
 from .progress import get_progress_state

 MMAP_TORCH_FILES = args.mmap_torch_files
@@ -89,12 +89,17 @@ def _get_progress_bar_enabled():
 setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))


-def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False):
+class FileMetadata(TypedDict):
+    format: NotRequired[Literal["gguf"]]
+
+
+def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], Optional[FileMetadata]]:
     if device is None:
         device = torch.device("cpu")
     if ckpt is None:
         raise FileNotFoundError("the checkpoint was not found")
-    metadata = None
+    metadata: Optional[dict[str, str]] = None
+    sd: dict[str, torch.Tensor] = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
             with safetensors.safe_open(Path(ckpt).resolve(strict=True), framework="pt", device=device.type) as f:
@@ -128,6 +133,10 @@ load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
         sd: dict[str, torch.Tensor] = {}
         for checkpoint_file in checkpoint_files:
             sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
+    elif ckpt.lower().endswith(".gguf"):
+        # from gguf
+        sd = gguf_sd_loader(ckpt)
+        metadata = {"format": "gguf"}
     else:
         try:
             torch_args = {}
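A short usage sketch of the new loader path (the filename is illustrative):

```python
# Loading a GGUF checkpoint now returns the tensors plus a metadata tag that
# downstream model detection uses to select GGMLOps.
from comfy.utils import load_torch_file

sd, metadata = load_torch_file("model-Q8_0.gguf", return_metadata=True)
assert metadata == {"format": "gguf"}
sd_only = load_torch_file("model-Q8_0.gguf")  # plain dict when return_metadata is False
```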
@@ -138,7 +147,7 @@ load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
                 pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
             else:
                 logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
-                pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
+                pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
             if "state_dict" in pl_sd:
                 sd = pl_sd["state_dict"]
             else:
@@ -153,14 +162,14 @@ load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
         try:
             # wrong extension is most likely, try to load as safetensors anyway
             sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
-            return sd
         except Exception:
             msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again"
             if hasattr(exc_info, "add_note"):
                 exc_info.add_note(msg)
             else:
                 logger.error(msg, exc_info=exc_info)
             raise exc_info
+    if sd is None:
+        raise exc_info
     return (sd, metadata) if return_metadata else sd


@@ -1125,6 +1134,7 @@ def _progress_bar_update(value: float, total: float, preview_image_or_data: Opti
     get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
     server.send_sync("progress", progress, client_id)
+


 def set_progress_bar_enabled(enabled: bool):
     warnings.warn(
         "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",
@@ -107,6 +107,7 @@ dependencies = [
     "setuptools",
     "alembic",
     "SQLAlchemy",
+    "gguf",
 ]

 [build-system]