Experimental GGUF support

commit 69a4906964
parent fe4057385d
@@ -278,7 +278,7 @@ class PromptServer(ExecutorToClientProgress):
                             sid,
                         )

-                        logger.info(
+                        logger.debug(
                            f"Feature flags negotiated for client {sid}: {client_flags}"
                         )
                     first_message = False

@@ -13,7 +13,7 @@ from typing import Any, NamedTuple, Optional, Iterable

 from .platform_path import construct_path

-supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json"])
+supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json", ".gguf"])
 extension_mimetypes_cache = {
     "webp": "image",
     "fbx": "model",

comfy/gguf.py (new file, 1203 lines)
File diff suppressed because it is too large

@@ -1,11 +1,13 @@
 import json
 import logging
 import math
+from typing import Optional

 import torch

 from . import supported_models, utils
 from . import supported_models_base
+from .gguf import GGMLOps


 def count_blocks(state_dict_keys, prefix_string):

@@ -620,7 +622,7 @@ def model_config_from_unet_config(unet_config, state_dict=None):
     return None


-def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata:Optional[dict]=None):
     unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
     if unet_config is None:
         return None

@@ -639,6 +641,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         else:
             model_config.optimizations["fp8"] = True

+    if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
+        model_config.custom_operations = GGMLOps
+
     return model_config


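The two model_detection hunks above are the detection half of the feature: once a state dict is tagged with {"format": "gguf"}, model_config_from_unet switches the model's layer implementations to GGMLOps. A minimal call sketch, assuming a real GGUF diffusion model on disk (the path is an example; the comfy.* import paths follow this repository's layout):

```python
from comfy import model_detection, utils
from comfy.gguf import GGMLOps

# example path; any GGUF diffusion model checkpoint would do
sd, metadata = utils.load_torch_file("models/diffusion_models/model-Q8_0.gguf", return_metadata=True)

model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
    # per the hunk above, GGUF metadata switches the config to the GGML-aware ops
    assert model_config.custom_operations is GGMLOps
```
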
@@ -557,6 +557,8 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"),
     HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"),
     HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"),
+    HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"),
 ], folder_names=["diffusion_models", "unet"])

 KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([

@@ -19,6 +19,7 @@ from __future__ import annotations

 import collections
 import copy
+import dataclasses
 import inspect
 import logging
 import typing

@@ -37,6 +38,7 @@ from . import utils
 from .comfy_types import UnetWrapperFunction
 from .component_model.deprecation import _deprecate_method
 from .float import stochastic_rounding
+from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
 from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
 from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue
 from .model_base import BaseModel

@@ -221,6 +223,13 @@ class MemoryCounter:
         self.value -= used


+@dataclasses.dataclass
+class GGUFQuantization:
+    loaded_from_gguf: bool = False
+    mmap_released: bool = False
+    patch_on_device: bool = False
+
+
 class ModelPatcher(ModelManageable):
     def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size

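The new GGUFQuantization dataclass is per-patcher bookkeeping; a standalone sketch of the copy semantics that the clone() hunk further below relies on (nothing assumed beyond the fields shown above):

```python
import copy
import dataclasses


@dataclasses.dataclass
class GGUFQuantization:
    loaded_from_gguf: bool = False
    mmap_released: bool = False
    patch_on_device: bool = False


state = GGUFQuantization(loaded_from_gguf=True, mmap_released=True)
clone = copy.copy(state)       # shallow copy, as in ModelPatcher.clone()
clone.mmap_released = False    # a clone releases mmap again on its next load()

print(state)   # GGUFQuantization(loaded_from_gguf=True, mmap_released=True, patch_on_device=False)
print(clone)   # GGUFQuantization(loaded_from_gguf=True, mmap_released=False, patch_on_device=False)
```
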
@@ -257,6 +266,9 @@ class ModelPatcher(ModelManageable):
         self.forced_hooks: Optional[HookGroup] = None # NOTE: only used for CLIP at this time
         self.is_clip = False
         self.hook_mode = EnumHookMode.MaxSpeed
+        self.gguf = GGUFQuantization()
+        if isinstance(model, BaseModel) and model.operations == GGMLOps:
+            self.gguf.loaded_from_gguf = True

     @property
     def model_options(self) -> ModelOptions:

@@ -361,6 +373,9 @@ class ModelPatcher(ModelManageable):
         n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks
         n.is_clip = self.is_clip
         n.hook_mode = self.hook_mode
+        n.gguf = copy.copy(self.gguf)
+        # todo: when is this set back to False? when would it make sense to?
+        n.gguf.mmap_released = False

         for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
             callback(self, n)

@@ -612,6 +627,19 @@ class ModelPatcher(ModelManageable):
         weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update

+        # from gguf
+        if is_quantized(weight):
+            out_weight = weight.to(device_to)
+            patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device)
+            # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight)
+            out_weight.patches = [(patches, key)]
+            if inplace_update:
+                utils.copy_to_param(self.model, key, out_weight)
+            else:
+                utils.set_attr_param(self.model, key, out_weight)
+            return
+        # end gguf
+
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

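A hypothetical, self-contained illustration of the deferred-patch idea in the branch above: for GGML-quantized weights, LoRA-style patches are not baked into the stored tensor, they ride along on it and are applied after dequantization at compute time. The diff for comfy/gguf.py is suppressed, so none of the names below come from it; they only mimic the shape of the mechanism:

```python
import torch

# stand-in for a GGML-quantized weight; the real code uses a torch.Tensor subclass
weight = torch.nn.Parameter(torch.zeros(4, 4), requires_grad=False)
weight.patches = [(torch.full((4, 4), 0.5), "some.layer.weight")]  # recorded, not applied


def dequantize_and_patch(w: torch.Tensor) -> torch.Tensor:
    out = w.detach().float()                      # stand-in for real dequantization
    for delta, _key in getattr(w, "patches", []):
        out = out + delta                         # a real impl runs the full patch math here
    return out


print(dequantize_and_patch(weight))
```
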
@@ -648,6 +676,9 @@ class ModelPatcher(ModelManageable):
         return loading

     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        if self.gguf.loaded_from_gguf:
+            force_patch_weights = True
+
         with self.use_ejected():
             self.unpatch_hooks()
             mem_counter = 0

@@ -744,6 +775,28 @@ class ModelPatcher(ModelManageable):
                 self.model.to(device_to)
                 mem_counter = self.model_size()

+            if self.gguf.loaded_from_gguf and not self.gguf.mmap_released:
+                # todo: when is mmap_released set to True?
+                linked = []
+                if lowvram_model_memory > 0:
+                    for n, m in self.model.named_modules():
+                        if hasattr(m, "weight"):
+                            device = getattr(m.weight, "device", None)
+                            if device == self.offload_device:
+                                linked.append((n, m))
+                                continue
+                        if hasattr(m, "bias"):
+                            device = getattr(m.bias, "device", None)
+                            if device == self.offload_device:
+                                linked.append((n, m))
+                                continue
+                if linked and self.load_device != self.offload_device:
+                    logger.info(f"gguf attempting to release mmap ({len(linked)})")
+                    for n, m in linked:
+                        # TODO: possible to OOM, find better way to detach
+                        m.to(self.load_device).to(self.offload_device)
+                self.gguf.mmap_released = True
+
             self._memory_measurements.lowvram_patch_counter += patch_counter

             self.model_device = device_to

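The m.to(self.load_device).to(self.offload_device) round trip above works because moving a module off the CPU and back allocates fresh CPU storage, so its tensors stop being views into the memory-mapped GGUF file and the mapping can be dropped. A tiny, hedged demonstration with plain torch (it only exercises the round trip when a CUDA device is present):

```python
import torch

layer = torch.nn.Linear(8, 8)
original_ptr = layer.weight.data_ptr()

if torch.cuda.is_available():
    layer.to("cuda").to("cpu")   # same trick as in the hunk: off-device and back
    # the weight now lives in newly allocated CPU memory, typically at a different address
    print(original_ptr, layer.weight.data_ptr())
```
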
@@ -774,6 +827,13 @@ class ModelPatcher(ModelManageable):

     def unpatch_model(self, device_to=None, unpatch_weights=True):
         self.eject_model()
+        if self.gguf.loaded_from_gguf and unpatch_weights:
+            for p in self.model.parameters():
+                if is_torch_compatible(p):
+                    continue
+                patches = self.patches
+                if len(patches) > 0:
+                    p.patches = []
         if unpatch_weights:
             self.unpatch_hooks()
             if self._memory_measurements.model_lowvram:

@@ -503,7 +503,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
         weight_dtype == torch.float16 and
         (compute_dtype == torch.float16 or compute_dtype is None)
     ):
-        logging.info("Using cublas ops")
+        logger.info("Using cublas ops")
         return cublas_ops

     if compute_dtype is None or weight_dtype == compute_dtype:

comfy/sd.py (20 changed lines)

@@ -23,6 +23,7 @@ from . import model_sampling
 from . import sd1_clip
 from . import sdxl_clip
 from . import utils
+from .component_model.deprecation import _deprecate_method
 from .hooks import EnumHookMode
 from .ldm.ace.vae.music_dcae_pipeline import MusicDCAE
 from .ldm.audio.autoencoder import AudioOobleckVAE

@@ -58,7 +59,7 @@ from .text_encoders import sa_t5
 from .text_encoders import sd2_clip
 from .text_encoders import sd3_clip
 from .text_encoders import wan
-from .utils import ProgressBar
+from .utils import ProgressBar, FileMetadata

 logger = logging.getLogger(__name__)


@@ -1098,7 +1099,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     return out


-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[str | dict] = None, ckpt_path=""):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[FileMetadata] = None, ckpt_path=""):
     if te_model_options is None:
         te_model_options = {}
     if model_options is None:

@@ -1182,7 +1183,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (_model_patcher, clip, vae, clipvision)


-def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = ""): # load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = "", metadata: Optional[FileMetadata] = None): # load unet in diffusers or regular format
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.


@@ -1192,6 +1193,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
             - dtype: Override model data type
             - custom_operations: Custom model operations
             - fp8_optimizations: Enable FP8 optimizations
+        metadata: file metadata

     Returns:
         ModelPatcher: A wrapped model instance that handles device management and weight loading.

@@ -1217,14 +1219,14 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
     parameters = utils.calculate_parameters(sd)
     weight_dtype = utils.weight_dtype(sd)
     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)

     if model_config is not None:
         new_sd = sd
     else:
         new_sd = model_detection.convert_diffusers_mmdit(sd, "")
         if new_sd is not None: # diffusers mmdit
-            model_config = model_detection.model_config_from_unet(new_sd, "")
+            model_config = model_detection.model_config_from_unet(new_sd, "", metadata=metadata)
             if model_config is None:
                 return None
         else: # diffusers unet

@@ -1269,21 +1271,21 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
 def load_diffusion_model(unet_path, model_options: dict = None):
     if model_options is None:
         model_options = {}
-    sd = utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path)
+    sd, metadata = utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path, metadata=metadata)
     if model is None:
         logger.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
     return model


+@_deprecate_method(message="The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model", version="*")
 def load_unet(unet_path, dtype=None):
-    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
     return load_diffusion_model(unet_path, model_options={"dtype": dtype})


+@_deprecate_method(message="The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict", version="*")
 def load_unet_state_dict(sd, dtype=None):
-    logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
     return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})


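A minimal usage sketch for the loader changes above, assuming a GGUF diffusion model already sits in the models folder (the path is an example; "from comfy import sd, utils" follows this repository's package layout):

```python
from comfy import sd, utils

ckpt = "models/diffusion_models/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"  # example path

# the convenience wrapper now threads GGUF metadata through automatically
patcher = sd.load_diffusion_model(ckpt)

# equivalent explicit form, mirroring what load_diffusion_model does internally
state_dict, metadata = utils.load_torch_file(ckpt, return_metadata=True)
assert metadata == {"format": "gguf"}
patcher = sd.load_diffusion_model_state_dict(state_dict, ckpt_path=ckpt, metadata=metadata)
```
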
@@ -31,7 +31,7 @@ import warnings
 from contextlib import contextmanager
 from pathlib import Path
 from pickle import UnpicklingError
-from typing import Optional, Any
+from typing import Optional, Any, Literal

 import numpy as np
 import safetensors.torch

@@ -40,15 +40,15 @@ from PIL import Image
 from einops import rearrange
 from torch.nn.functional import interpolate
 from tqdm import tqdm
+from typing_extensions import TypedDict, NotRequired

-from comfy_api import feature_flags
 from . import interruption, checkpoint_pickle
 from .cli_args import args
 from .component_model import files
 from .component_model.deprecation import _deprecate_method
 from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage
-from .component_model.queue_types import BinaryEventTypes
 from .execution_context import current_execution_context
+from .gguf import gguf_sd_loader
 from .progress import get_progress_state

 MMAP_TORCH_FILES = args.mmap_torch_files

@@ -89,12 +89,17 @@ def _get_progress_bar_enabled():
 setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))


-def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False):
+class FileMetadata(TypedDict):
+    format: NotRequired[Literal["gguf"]]
+
+
+def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], Optional[FileMetadata]]:
     if device is None:
         device = torch.device("cpu")
     if ckpt is None:
         raise FileNotFoundError("the checkpoint was not found")
-    metadata = None
+    metadata: Optional[dict[str, str]] = None
+    sd: dict[str, torch.Tensor] = None
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         try:
             with safetensors.safe_open(Path(ckpt).resolve(strict=True), framework="pt", device=device.type) as f:

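Because FileMetadata is a TypedDict, callers of load_torch_file(..., return_metadata=True) get a statically checkable shape for the second return value. A small standalone sketch mirroring the definitions above:

```python
from typing import Literal, Optional
from typing_extensions import NotRequired, TypedDict


class FileMetadata(TypedDict):
    format: NotRequired[Literal["gguf"]]


def describe(metadata: Optional[FileMetadata]) -> str:
    if metadata is not None and metadata.get("format") == "gguf":
        return "GGML-quantized checkpoint"
    return "regular checkpoint"


print(describe({"format": "gguf"}))  # GGML-quantized checkpoint
print(describe(None))                # regular checkpoint
```
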
@@ -128,6 +133,10 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
         sd: dict[str, torch.Tensor] = {}
         for checkpoint_file in checkpoint_files:
             sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
+    elif ckpt.lower().endswith(".gguf"):
+        # from gguf
+        sd = gguf_sd_loader(ckpt)
+        metadata = {"format": "gguf"}
     else:
         try:
             torch_args = {}

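The diff for comfy/gguf.py is suppressed above, so the sketch below only illustrates what a loader with gguf_sd_loader's signature typically does using the "gguf" package added to the dependencies: read the file with GGUFReader and expose each tensor under its original name. A real implementation additionally wraps quantized tensors so their GGML quantization type and raw blocks survive until dequantization; the function name suffix and the plain zero-copy handling here are assumptions.

```python
import torch
from gguf import GGUFReader  # provided by the "gguf" dependency added in pyproject.toml


def gguf_sd_loader_sketch(path: str) -> dict[str, torch.Tensor]:
    reader = GGUFReader(path)
    state_dict = {}
    for tensor in reader.tensors:
        # tensor.data is a numpy view into the memory-mapped file (hence the mmap
        # bookkeeping added to ModelPatcher above); from_numpy keeps it zero-copy
        state_dict[tensor.name] = torch.from_numpy(tensor.data)
    return state_dict
```
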
@@ -138,7 +147,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
                 pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
             else:
                 logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
                 pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle)
             if "state_dict" in pl_sd:
                 sd = pl_sd["state_dict"]
             else:

@@ -153,14 +162,14 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal
         try:
             # wrong extension is most likely, try to load as safetensors anyway
             sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type)
-            return sd
         except Exception:
             msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again"
             if hasattr(exc_info, "add_note"):
                 exc_info.add_note(msg)
             else:
                 logger.error(msg, exc_info=exc_info)
-            raise exc_info
+        if sd is None:
+            raise exc_info
     return (sd, metadata) if return_metadata else sd


@@ -1125,6 +1134,7 @@ def _progress_bar_update(value: float, total: float, preview_image_or_data: Opti
     get_progress_state().update_progress(node_id, value, total, preview_image_or_data)
     server.send_sync("progress", progress, client_id)

+
 def set_progress_bar_enabled(enabled: bool):
     warnings.warn(
         "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. Use current_execution_context().server.receive_all_progress_notifications instead.",

@@ -107,6 +107,7 @@ dependencies = [
     "setuptools",
     "alembic",
     "SQLAlchemy",
+    "gguf",
 ]

 [build-system]