Experimental GGUF support

doctorpangloss 2025-07-28 17:02:20 -07:00
parent fe4057385d
commit 69a4906964
10 changed files with 1304 additions and 21 deletions

View File

@@ -278,7 +278,7 @@ class PromptServer(ExecutorToClientProgress):
                         sid,
                     )
-                    logger.info(
+                    logger.debug(
                         f"Feature flags negotiated for client {sid}: {client_flags}"
                     )
                     first_message = False

View File

@@ -13,7 +13,7 @@ from typing import Any, NamedTuple, Optional, Iterable
 from .platform_path import construct_path
 
-supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json"])
+supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json", ".gguf"])
 
 extension_mimetypes_cache = {
     "webp": "image",
     "fbx": "model",

comfy/gguf.py (new file, +1203 lines)

File diff suppressed because it is too large.
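Since the comfy/gguf.py diff is collapsed, its exports (gguf_sd_loader, GGMLOps, is_quantized, is_torch_compatible, move_patch_to_device) only appear as call sites in the rest of the commit. As orientation only, a minimal state-dict loader over the gguf package's GGUFReader could look like the sketch below; everything beyond the GGUFReader API (the tensor_type/tensor_shape attributes, the lazy-dequantization tagging) is an assumption, not the contents of the suppressed file.

# Sketch only: roughly what a gguf_sd_loader is expected to do, using the public
# GGUFReader API from the "gguf" package. The attributes attached here are
# illustrative; the real file likely wraps quantized data in a dedicated tensor
# subclass instead.
import torch
from gguf import GGUFReader, GGMLQuantizationType

UNQUANTIZED = (GGMLQuantizationType.F32, GGMLQuantizationType.F16)

def gguf_sd_loader_sketch(path: str) -> dict[str, torch.Tensor]:
    reader = GGUFReader(path)  # memory-maps the file, so tensor data stays on disk
    sd: dict[str, torch.Tensor] = {}
    for t in reader.tensors:
        data = torch.from_numpy(t.data)  # raw packed bytes for quantized types
        if t.tensor_type in UNQUANTIZED:
            data = data.reshape(tuple(reversed(t.shape.tolist())))
        else:
            # keep quantized blocks packed; record what is needed to dequantize later
            data.tensor_type = t.tensor_type
            data.tensor_shape = torch.Size(reversed(t.shape.tolist()))
        sd[t.name] = data
    return sd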

View File

@@ -1,11 +1,13 @@
 import json
 import logging
 import math
+from typing import Optional
 
 import torch
 
 from . import supported_models, utils
 from . import supported_models_base
+from .gguf import GGMLOps
 
 
 def count_blocks(state_dict_keys, prefix_string):
@@ -620,7 +622,7 @@ def model_config_from_unet_config(unet_config, state_dict=None):
     return None
 
 
-def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata:Optional[dict]=None):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
     if unet_config is None:
         return None
@@ -639,6 +641,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         else:
             model_config.optimizations["fp8"] = True
 
+    if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
+        model_config.custom_operations = GGMLOps
+
     return model_config

View File

@@ -557,6 +557,8 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
 ], folder_names=["diffusion_models", "unet"])
 
 KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
 
 import collections
 import copy
+import dataclasses
 import inspect
 import logging
 import typing
@@ -37,6 +38,7 @@ from . import utils
 from .comfy_types import UnetWrapperFunction
 from .component_model.deprecation import _deprecate_method
 from .float import stochastic_rounding
+from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
 from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
 from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel
@@ -221,6 +223,13 @@ class MemoryCounter:
         self.value -= used
 
 
+@dataclasses.dataclass
+class GGUFQuantization:
+    loaded_from_gguf: bool = False
+    mmap_released: bool = False
+    patch_on_device: bool = False
+
+
 class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
@@ -257,6 +266,9 @@ class ModelPatcher(ModelManageable):
         self.forced_hooks: Optional[HookGroup] = None  # NOTE: only used for CLIP at this time
         self.is_clip = False
         self.hook_mode = EnumHookMode.MaxSpeed
+        self.gguf = GGUFQuantization()
+        if isinstance(model, BaseModel) and model.operations == GGMLOps:
+            self.gguf.loaded_from_gguf = True
 
     @property
     def model_options(self) -> ModelOptions:
@@ -361,6 +373,9 @@ class ModelPatcher(ModelManageable):
         n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
         n.is_clip = self.is_clip
         n.hook_mode = self.hook_mode
+        n.gguf = copy.copy(self.gguf)
+        # todo: when is this set back to False? when would it make sense to?
+        n.gguf.mmap_released = False
 
         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
             callback(self, n)
@@ -612,6 +627,19 @@ class ModelPatcher(ModelManageable):
         weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update
 
+        # from gguf
+        if is_quantized(weight):
+            out_weight = weight.to(device_to)
+            patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
+            # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
+            out_weight.patches = [(patches, key)]
+            if inplace_update:
+                utils.copy_to_param(self.model, key, out_weight)
+            else:
+                utils.set_attr_param(self.model, key, out_weight)
+            return
+        # end gguf
+
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
@@ -648,6 +676,9 @@ class ModelPatcher(ModelManageable):
         return loading
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        if self.gguf.loaded_from_gguf:
+            force_patch_weights = True
         with self.use_ejected():
             self.unpatch_hooks()
             mem_counter = 0
@@ -744,6 +775,28 @@ class ModelPatcher(ModelManageable):
                 self.model.to(device_to)
                 mem_counter = self.model_size()
 
+            if self.gguf.loaded_from_gguf and not self.gguf.mmap_released:
+                # todo: when is mmap_released set to True?
+                linked = []
+                if lowvram_model_memory > 0:
+                    for n, m in self.model.named_modules():
+                        if hasattr(m, "weight"):
+                            device = getattr(m.weight, "device", None)
+                            if device == self.offload_device:
+                                linked.append((n, m))
+                                continue
+                        if hasattr(m, "bias"):
+                            device = getattr(m.bias, "device", None)
+                            if device == self.offload_device:
+                                linked.append((n, m))
+                                continue
+                if linked and self.load_device != self.offload_device:
+                    logger.info(f"gguf attempting to release mmap ({len(linked)})")
+                    for n, m in linked:
+                        # TODO: possible to OOM, find better way to detach
+                        m.to(self.load_device).to(self.offload_device)
+                self.gguf.mmap_released = True
+
             self._memory_measurements.lowvram_patch_counter += patch_counter
             self.model_device = device_to
@@ -774,6 +827,13 @@ class ModelPatcher(ModelManageable):
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
+        if self.gguf.loaded_from_gguf and unpatch_weights:
+            for p in self.model.parameters():
+                if is_torch_compatible(p):
+                    continue
+                patches = self.patches
+                if len(patches) > 0:
+                    p.patches = []
         if unpatch_weights:
             self.unpatch_hooks()
             if self._memory_measurements.model_lowvram:
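Note that patch_weight_to_device() never dequantizes a GGUF weight: it moves the packed tensor and only records the LoRA patches on it (out_weight.patches), which is presumably why load() forces force_patch_weights=True for GGUF models, so the patches get attached up front. The layers provided by GGMLOps would then dequantize and apply the recorded patches at compute time. A heavily simplified, assumption-laden sketch of that idea follows; dequantize_tensor is a stand-in and the patch math is elided, since the real logic lives in the suppressed comfy/gguf.py.

# Illustrative sketch of the dequantize-at-compute-time idea implied by the changes
# above. Names other than torch's own are assumptions, not the actual implementation.
import torch
import torch.nn.functional as F


def dequantize_tensor(weight: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # stand-in: real code decodes packed GGML blocks (Q8_0, Q4_K, ...) into floats
    return weight.to(dtype)


class GGMLLinearSketch(torch.nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = dequantize_tensor(self.weight, x.dtype)
        # re-apply the LoRA deltas that patch_weight_to_device() recorded, if any
        for patches, key in getattr(self.weight, "patches", []):
            pass  # elided: blend each patch into the dequantized weight
        bias = self.bias if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, weight, bias)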

View File

@@ -503,7 +503,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
             weight_dtype == torch.float16 and
             (compute_dtype == torch.float16 or compute_dtype is None)
     ):
-        logging.info("Using cublas ops")
+        logger.info("Using cublas ops")
         return cublas_ops
 
     if compute_dtype is None or weight_dtype == compute_dtype:

View File

@@ -23,6 +23,7 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
+from .component_model.deprecation import _deprecate_method
 from .hooks import EnumHookMode
 from .ldm.ace.vae.music_dcae_pipeline import MusicDCAE
 from .ldm.audio.autoencoder import AudioOobleckVAE
@@ -58,7 +59,7 @@ from .text_encoders import sa_t5
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
 from .text_encoders import wan
-from .utils import ProgressBar
+from .utils import ProgressBar, FileMetadata
 
 
 logger = logging.getLogger(__name__)
@@ -1098,7 +1099,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return out
 
 
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[str | dict] = None, ckpt_path=""):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[FileMetadata] = None, ckpt_path=""):
     if te_model_options is None:
         te_model_options = {}
     if model_options is None:
@@ -1182,7 +1183,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (_model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = ""):  # load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = "", metadata: Optional[FileMetadata] = None):  # load unet in diffusers or regular format
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1192,6 +1193,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
             - dtype: Override model data type
             - custom_operations: Custom model operations
             - fp8_optimizations: Enable FP8 optimizations
+        metadata: file metadata
 
     Returns:
         ModelPatcher: A wrapped model instance that handles device management and weight loading.
@@ -1217,14 +1219,14 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
     parameters = utils.calculate_parameters(sd)
     weight_dtype = utils.weight_dtype(sd)
 
     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
 
     if model_config is not None:
         new_sd = sd
     else:
         new_sd = model_detection.convert_diffusers_mmdit(sd, "")
         if new_sd is not None:  # diffusers mmdit
-            model_config = model_detection.model_config_from_unet(new_sd, "")
+            model_config = model_detection.model_config_from_unet(new_sd, "", metadata=metadata)
             if model_config is None:
                 return None
         else:  # diffusers unet
@@ -1269,21 +1271,21 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
 def load_diffusion_model(unet_path, model_options: dict = None):
     if model_options is None:
         model_options = {}
-    sd = utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path)
+    sd, metadata = utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path, metadata=metadata)
     if model is None:
         logger.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
     return model
 
 
+@_deprecate_method(message="The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model", version="*")
 def load_unet(unet_path, dtype=None):
-    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
     return load_diffusion_model(unet_path, model_options={"dtype": dtype})
 
 
+@_deprecate_method(message="The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict", version="*")
 def load_unet_state_dict(sd, dtype=None):
-    logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
     return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
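With these pieces in place, a .gguf UNet goes through the same entry point as a safetensors one. A short usage sketch follows; the path is only an example of a downloaded GGUF checkpoint.

# Usage sketch: load a GGUF-quantized diffusion model like any other UNet checkpoint.
# load_torch_file() keys off the .gguf extension and returns metadata={"format": "gguf"},
# which makes model_config_from_unet() select GGMLOps for the resulting ModelPatcher.
from comfy import sd

model_patcher = sd.load_diffusion_model(
    "models/diffusion_models/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"
)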

View File

@@ -31,7 +31,7 @@ import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from pickle import UnpicklingError
-from typing import Optional, Any
+from typing import Optional, Any, Literal
 
 import numpy as np
 import safetensors.torch
@@ -40,15 +40,15 @@ from PIL import Image
 from einops import rearrange
 from torch.nn.functional import interpolate
 from tqdm import tqdm
+from typing_extensions import TypedDict, NotRequired
 
-from comfy_api import feature_flags
 from . import interruption, checkpoint_pickle
 from .cli_args import args
 from .component_model import files
 from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
-from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
+from .gguf import gguf_sd_loader
 from .progress import get_progress_state
 
 MMAP_TORCH_FILES = args.mmap_torch_files
@@ -89,12 +89,17 @@ def _get_progress_bar_enabled():
 setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
 
 
-def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False):
+class FileMetadata(TypedDict):
+    format: NotRequired[Literal["gguf"]]
+
+
+def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], Optional[FileMetadata]]:
     if device is None:
         device = torch.device("cpu")
     if ckpt is None:
         raise FileNotFoundError("the checkpoint was not found")
-    metadata = None
+    metadata: Optional[dict[str, str]] = None
+    sd: dict[str, torch.Tensor] = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
             with safetensors.safe_open(Path(ckpt).resolve(strict=True), framework="pt", device=device.type) as f:
@@ -128,6 +133,10 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
             sd: dict[str, torch.Tensor] = {}
             for checkpoint_file in checkpoint_files:
                 sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
+    elif ckpt.lower().endswith(".gguf"):
+        # from gguf
+        sd = gguf_sd_loader(ckpt)
+        metadata = {"format": "gguf"}
     else:
         try:
             torch_args = {}
@@ -138,7 +147,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
                 pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
             else:
                 logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
                 pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
             if "state_dict" in pl_sd:
                 sd = pl_sd["state_dict"]
             else:
@@ -153,14 +162,14 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
         try:
             # wrong extension is most likely, try to load as safetensors anyway
             sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
-            return sd
         except Exception:
             msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again"
             if hasattr(exc_info, "add_note"):
                 exc_info.add_note(msg)
             else:
                 logger.error(msg, exc_info=exc_info)
-            raise exc_info
+            if sd is None:
+                raise exc_info
 
     return (sd, metadata) if return_metadata else sd
@@ -1125,6 +1134,7 @@ def _progress_bar_update(value: float, total: float, preview_image_or_data: Opti
     get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
     server.send_sync("progress", progress, client_id)
 
+
 def set_progress_bar_enabled(enabled: bool):
     warnings.warn(
         "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",

View File

@@ -107,6 +107,7 @@ dependencies = [
     "setuptools",
     "alembic",
     "SQLAlchemy",
+    "gguf",
 ]
 
 [build-system]