Experimental GGUF support

doctorpangloss 2025-07-28 17:02:20 -07:00
parent fe4057385d
commit 69a4906964
10 changed files with 1304 additions and 21 deletions

View File

@ -278,7 +278,7 @@ class PromptServer(ExecutorToClientProgress):
sid,
)
logger.info(
logger.debug(
f"Feature flags negotiated for client {sid}: {client_flags}"
)
first_message = False

View File

@ -13,7 +13,7 @@ from typing import Any, NamedTuple, Optional, Iterable
from .platform_path import construct_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json"])
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json", ".gguf"])
extension_mimetypes_cache = {
"webp": "image",
"fbx": "model",

comfy/gguf.py (new file, 1203 lines)

File diff suppressed because it is too large
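Since the diff for comfy/gguf.py is suppressed, the sketch below reconstructs only the small surface the rest of this commit imports from it (gguf_sd_loader, is_quantized, is_torch_compatible, move_patch_to_device; GGMLOps is the comfy.ops-style operations class that dequantizes at forward time). This is an assumption-based outline, not the actual 1203-line implementation.

```python
# Hedged sketch of the comfy/gguf.py surface used elsewhere in this commit.
# Quant formats, key remapping, tensor wrapping, and GGMLOps details are assumptions.
import gguf   # the PyPI "gguf" reader added to pyproject.toml in this commit
import torch


def is_torch_compatible(tensor) -> bool:
    # F32/F16 GGML payloads can be used directly as ordinary torch tensors.
    return tensor is None or getattr(tensor, "tensor_type", None) in (
        None,
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
    )


def is_quantized(tensor) -> bool:
    return not is_torch_compatible(tensor)


def move_patch_to_device(item, device):
    # Recursively move patch payloads (tensors nested in lists/tuples) onto `device`.
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    if isinstance(item, (list, tuple)):
        return type(item)(move_patch_to_device(x, device) for x in item)
    return item


def gguf_sd_loader(path: str) -> dict[str, torch.Tensor]:
    # Stripped-down loader: read every tensor in the GGUF file into a state dict,
    # keeping quantized payloads as raw byte tensors to be dequantized at use time.
    reader = gguf.GGUFReader(path)
    return {t.name: torch.from_numpy(t.data) for t in reader.tensors}
```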

View File

@ -1,11 +1,13 @@
import json
import logging
import math
from typing import Optional
import torch
from . import supported_models, utils
from . import supported_models_base
from .gguf import GGMLOps
def count_blocks(state_dict_keys, prefix_string):
@ -620,7 +622,7 @@ def model_config_from_unet_config(unet_config, state_dict=None):
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata:Optional[dict]=None):
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None:
return None
@ -639,6 +641,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
model_config.custom_operations = GGMLOps
return model_config
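The effect of this hunk: when the loader reports GGUF metadata, detection keeps the usual architecture guess but swaps the config's operations class to GGMLOps. A hedged usage sketch, with a placeholder path:

```python
from comfy import model_detection, utils

# placeholder path to any GGUF-quantized diffusion model
sd, metadata = utils.load_torch_file("wan2.2_t2v_q8_0.gguf", return_metadata=True)
config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if config is not None:
    # with metadata == {"format": "gguf"} the detected config carries GGMLOps
    print(config.custom_operations)
```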

View File

@ -557,6 +557,8 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
], folder_names=["diffusion_models", "unet"])
KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([

View File

@ -19,6 +19,7 @@ from __future__ import annotations
import collections
import copy
import dataclasses
import inspect
import logging
import typing
@ -37,6 +38,7 @@ from . import utils
from .comfy_types import UnetWrapperFunction
from .component_model.deprecation import _deprecate_method
from .float import stochastic_rounding
from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
from .model_base import BaseModel
@ -221,6 +223,13 @@ class MemoryCounter:
self.value -= used
@dataclasses.dataclass
class GGUFQuantization:
loaded_from_gguf: bool = False
mmap_released: bool = False
patch_on_device: bool = False
class ModelPatcher(ModelManageable):
def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
@ -257,6 +266,9 @@ class ModelPatcher(ModelManageable):
self.forced_hooks: Optional[HookGroup] = None # NOTE: only used for CLIP at this time
self.is_clip = False
self.hook_mode = EnumHookMode.MaxSpeed
self.gguf = GGUFQuantization()
if isinstance(model, BaseModel) and model.operations == GGMLOps:
self.gguf.loaded_from_gguf = True
@property
def model_options(self) -> ModelOptions:
@ -361,6 +373,9 @@ class ModelPatcher(ModelManageable):
n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
n.is_clip = self.is_clip
n.hook_mode = self.hook_mode
n.gguf = copy.copy(self.gguf)
# todo: when is this set back to False? when would it make sense to?
n.gguf.mmap_released = False
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
@ -612,6 +627,19 @@ class ModelPatcher(ModelManageable):
weight, set_func, convert_func = get_key_weight(self.model, key)
inplace_update = self.weight_inplace_update or inplace_update
# from gguf
if is_quantized(weight):
out_weight = weight.to(device_to)
patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
# TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
out_weight.patches = [(patches, key)]
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr_param(self.model, key, out_weight)
return
# end gguf
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
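For GGML-quantized weights the patcher does not bake LoRA patches into the stored tensor; it only records them on the tensor (out_weight.patches) so they can be applied after dequantization at forward time. Below is a hypothetical sketch of how a GGMLOps-style Linear could consume that list; the dequantize stand-in and the (strength, diff) patch layout are assumptions, since the real path lives in the suppressed comfy/gguf.py.

```python
import torch
import torch.nn.functional as F


def dequantize_tensor(weight, dtype):
    # Stand-in only: real Q8_0/Q4_K/... dequantization lives in comfy/gguf.py.
    return weight.to(dtype)


class PatchedQuantLinear(torch.nn.Linear):
    # Hypothetical GGMLOps-style layer showing where the recorded patches would apply.
    def forward(self, x):
        weight = dequantize_tensor(self.weight, x.dtype)
        for patches, key in getattr(self.weight, "patches", []):
            # key names the parameter the patches target
            for strength, diff in patches:   # assumed two-tuple layout for illustration
                weight = weight + strength * diff.to(weight)
        return F.linear(x, weight, self.bias)
```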
@ -648,6 +676,9 @@ class ModelPatcher(ModelManageable):
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
if self.gguf.loaded_from_gguf:
force_patch_weights = True
with self.use_ejected():
self.unpatch_hooks()
mem_counter = 0
@ -744,6 +775,28 @@ class ModelPatcher(ModelManageable):
self.model.to(device_to)
mem_counter = self.model_size()
if self.gguf.loaded_from_gguf and not self.gguf.mmap_released:
# todo: when is mmap_released set to True?
linked = []
if lowvram_model_memory > 0:
for n, m in self.model.named_modules():
if hasattr(m, "weight"):
device = getattr(m.weight, "device", None)
if device == self.offload_device:
linked.append((n, m))
continue
if hasattr(m, "bias"):
device = getattr(m.bias, "device", None)
if device == self.offload_device:
linked.append((n, m))
continue
if linked and self.load_device != self.offload_device:
logger.info(f"gguf attempting to release mmap ({len(linked)})")
for n, m in linked:
# TODO: possible to OOM, find better way to detach
m.to(self.load_device).to(self.offload_device)
self.gguf.mmap_released = True
self._memory_measurements.lowvram_patch_counter += patch_counter
self.model_device = device_to
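The round-trip m.to(self.load_device).to(self.offload_device) exists to break the alias to the memory-mapped GGUF file: moving a module to the load device and back allocates fresh CPU storage, so the mapping can be dropped. A tiny illustration with a plain tensor (devices are placeholders for the example):

```python
import torch

if torch.cuda.is_available():
    weight = torch.zeros(4, 4)              # stand-in for an mmap-backed parameter
    detached = weight.to("cuda").to("cpu")  # fresh storage, no longer aliases the original
    assert detached.data_ptr() != weight.data_ptr()
```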
@ -774,6 +827,13 @@ class ModelPatcher(ModelManageable):
def unpatch_model(self, device_to=None, unpatch_weights=True):
self.eject_model()
if self.gguf.loaded_from_gguf and unpatch_weights:
for p in self.model.parameters():
if is_torch_compatible(p):
continue
patches = self.patches
if len(patches) > 0:
p.patches = []
if unpatch_weights:
self.unpatch_hooks()
if self._memory_measurements.model_lowvram:

View File

@ -503,7 +503,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
weight_dtype == torch.float16 and
(compute_dtype == torch.float16 or compute_dtype is None)
):
logging.info("Using cublas ops")
logger.info("Using cublas ops")
return cublas_ops
if compute_dtype is None or weight_dtype == compute_dtype:

View File

@ -23,6 +23,7 @@ from . import model_sampling
from . import sd1_clip
from . import sdxl_clip
from . import utils
from .component_model.deprecation import _deprecate_method
from .hooks import EnumHookMode
from .ldm.ace.vae.music_dcae_pipeline import MusicDCAE
from .ldm.audio.autoencoder import AudioOobleckVAE
@ -58,7 +59,7 @@ from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import wan
from .utils import ProgressBar
from .utils import ProgressBar, FileMetadata
logger = logging.getLogger(__name__)
@ -1098,7 +1099,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[str | dict] = None, ckpt_path=""):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[FileMetadata] = None, ckpt_path=""):
if te_model_options is None:
te_model_options = {}
if model_options is None:
@ -1182,7 +1183,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (_model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = ""): # load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = "", metadata: Optional[FileMetadata] = None): # load unet in diffusers or regular format
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1192,6 +1193,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
- dtype: Override model data type
- custom_operations: Custom model operations
- fp8_optimizations: Enable FP8 optimizations
metadata: file metadata
Returns:
ModelPatcher: A wrapped model instance that handles device management and weight loading.
@ -1217,14 +1219,14 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
parameters = utils.calculate_parameters(sd)
weight_dtype = utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
else:
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
if new_sd is not None: # diffusers mmdit
model_config = model_detection.model_config_from_unet(new_sd, "")
model_config = model_detection.model_config_from_unet(new_sd, "", metadata=metadata)
if model_config is None:
return None
else: # diffusers unet
@ -1269,21 +1271,21 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
def load_diffusion_model(unet_path, model_options: dict = None):
if model_options is None:
model_options = {}
sd = utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path)
sd, metadata = utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path, metadata=metadata)
if model is None:
logger.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
return model
@_deprecate_method(message="The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model", version="*")
def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
@_deprecate_method(message="The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict", version="*")
def load_unet_state_dict(sd, dtype=None):
logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
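With the utils.py changes below, a GGUF checkpoint goes through the same entry point as a safetensors one. A hedged end-to-end sketch (the path is a placeholder; the package layout follows this diff):

```python
from comfy import sd

patcher = sd.load_diffusion_model("models/diffusion_models/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf")
# load_torch_file returns metadata == {"format": "gguf"}, detection selects GGMLOps,
# and the returned ModelPatcher reports patcher.gguf.loaded_from_gguf == True
```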

View File

@ -31,7 +31,7 @@ import warnings
from contextlib import contextmanager
from pathlib import Path
from pickle import UnpicklingError
from typing import Optional, Any
from typing import Optional, Any, Literal
import numpy as np
import safetensors.torch
@ -40,15 +40,15 @@ from PIL import Image
from einops import rearrange
from torch.nn.functional import interpolate
from tqdm import tqdm
from typing_extensions import TypedDict, NotRequired
from comfy_api import feature_flags
from . import interruption, checkpoint_pickle
from .cli_args import args
from .component_model import files
from .component_model.deprecation import _deprecate_method
from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
from .component_model.queue_types import BinaryEventTypes
from .execution_context import current_execution_context
from .gguf import gguf_sd_loader
from .progress import get_progress_state
MMAP_TORCH_FILES = args.mmap_torch_files
@ -89,12 +89,17 @@ def _get_progress_bar_enabled():
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False):
class FileMetadata(TypedDict):
format: NotRequired[Literal["gguf"]]
def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], Optional[FileMetadata]]:
if device is None:
device = torch.device("cpu")
if ckpt is None:
raise FileNotFoundError("the checkpoint was not found")
metadata = None
metadata: Optional[dict[str, str]] = None
sd: dict[str, torch.Tensor] = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
with safetensors.safe_open(Path(ckpt).resolve(strict=True), framework="pt", device=device.type) as f:
@ -128,6 +133,10 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
sd: dict[str, torch.Tensor] = {}
for checkpoint_file in checkpoint_files:
sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
elif ckpt.lower().endswith(".gguf"):
# from gguf
sd = gguf_sd_loader(ckpt)
metadata = {"format": "gguf"}
else:
try:
torch_args = {}
@ -138,7 +147,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@ -153,14 +162,14 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
try:
# wrong extension is most likely, try to load as safetensors anyway
sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
return sd
except Exception:
msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again"
if hasattr(exc_info, "add_note"):
exc_info.add_note(msg)
else:
logger.error(msg, exc_info=exc_info)
raise exc_info
if sd is None:
raise exc_info
return (sd, metadata) if return_metadata else sd
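The new return shape in one short sketch (placeholder paths; GGUF files always come back with format metadata):

```python
from comfy import utils

sd_only = utils.load_torch_file("model.safetensors")                  # dict, as before
sd, meta = utils.load_torch_file("model.gguf", return_metadata=True)  # (dict, FileMetadata)
if meta is not None and meta.get("format") == "gguf":
    print(f"loaded {len(sd)} GGUF tensors")
```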
@ -1125,6 +1134,7 @@ def _progress_bar_update(value: float, total: float, preview_image_or_data: Opti
get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
server.send_sync("progress", progress, client_id)
def set_progress_bar_enabled(enabled: bool):
warnings.warn(
"The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",

View File

@ -107,6 +107,7 @@ dependencies = [
"setuptools",
"alembic",
"SQLAlchemy",
"gguf",
]
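The new dependency is the reference GGUF reader from PyPI; a minimal smoke test that it can enumerate tensors in a file (placeholder path):

```python
import gguf

reader = gguf.GGUFReader("model.gguf")   # placeholder path
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.tensor_type, tensor.shape)
```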
[build-system]