diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index b456eae17..1c044edbd 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -278,7 +278,7 @@ class PromptServer(ExecutorToClientProgress):
                         sid,
                     )
-                    logger.info(
+                    logger.debug(
                         f"Feature flags negotiated for client {sid}: {client_flags}"
                     )
                     first_message = False
diff --git a/comfy/component_model/folder_path_types.py b/comfy/component_model/folder_path_types.py
index 42932e217..cba2da9ce 100644
--- a/comfy/component_model/folder_path_types.py
+++ b/comfy/component_model/folder_path_types.py
@@ -13,7 +13,7 @@ from typing import Any, NamedTuple, Optional, Iterable
 
 from .platform_path import construct_path
 
-supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft' ".index.json"])
+supported_pt_extensions = frozenset(['.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', ".index.json", ".gguf"])
 extension_mimetypes_cache = {
     "webp": "image",
     "fbx": "model",
diff --git a/comfy/gguf.py b/comfy/gguf.py
new file mode 100644
index 000000000..d69567532
--- /dev/null
+++ b/comfy/gguf.py
@@ -0,0 +1,1203 @@
+"""
+Copyright 2025 "City96" and Benjamin Berman
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+"""
+
+import argparse
+import logging
+import os
+import warnings
+
+import gguf
+import torch
+from safetensors.torch import load_file, save_file
+from sentencepiece import sentencepiece_model_pb2 as model
+from tqdm import tqdm
+
+from .lora import calculate_weight
+from .model_management import device_supports_non_blocking
+from .ops import cast_to, manual_cast
+
+logger = logging.getLogger(__name__)
+
+QUANTIZATION_THRESHOLD = 1024
+REARRANGE_THRESHOLD = 512
+MAX_TENSOR_NAME_LENGTH = 127
+MAX_TENSOR_DIMS = 4
+TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
+IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}
+
+
+class ModelTemplate:
+    arch = "invalid" # string describing architecture
+    shape_fix = False # whether to reshape tensors
+    keys_detect = [] # list of lists to match in state dict
+    keys_banned = [] # list of keys that should mark model as invalid for conversion
+    keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason
+    keys_ignore = [] # list of strings to ignore keys by when found
+
+    def handle_nd_tensor(self, key, data):
+        raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! 
({key} @ {data.shape})") + + +class ModelFlux(ModelTemplate): + arch = "flux" + keys_detect = [ + ("transformer_blocks.0.attn.norm_added_k.weight",), + ("double_blocks.0.img_attn.proj.weight",), + ] + keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight", ] + + +class ModelSD3(ModelTemplate): + arch = "sd3" + keys_detect = [ + ("transformer_blocks.0.attn.add_q_proj.weight",), + ("joint_blocks.0.x_block.attn.qkv.weight",), + ] + keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight", ] + + +class ModelAura(ModelTemplate): + arch = "aura" + keys_detect = [ + ("double_layers.3.modX.1.weight",), + ("joint_transformer_blocks.3.ff_context.out_projection.weight",), + ] + keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight", ] + + +class ModelHiDream(ModelTemplate): + arch = "hidream" + keys_detect = [ + ( + "caption_projection.0.linear.weight", + "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight" + ) + ] + keys_hiprec = [ + # nn.parameter, can't load from BF16 ver + ".ff_i.gate.weight", + "img_emb.emb_pos" + ] + + +class CosmosPredict2(ModelTemplate): + arch = "cosmos" + keys_detect = [ + ( + "blocks.0.mlp.layer1.weight", + "blocks.0.adaln_modulation_cross_attn.1.weight", + ) + ] + keys_hiprec = ["pos_embedder"] + keys_ignore = ["_extra_state", "accum_"] + + +class ModelHyVid(ModelTemplate): + arch = "hyvid" + keys_detect = [ + ( + "double_blocks.0.img_attn_proj.weight", + "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", + ) + ] + + def handle_nd_tensor(self, key, data): + # hacky but don't have any better ideas + path = f"./fix_5d_tensors_{self.arch}.safetensors" # TODO: somehow get a path here?? + if os.path.isfile(path): + raise RuntimeError(f"5D tensor fix file already exists! {path}") + fsd = {key: torch.from_numpy(data)} + tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}") + save_file(fsd, path) + + +class ModelWan(ModelHyVid): + arch = "wan" + keys_detect = [ + ( + "blocks.0.self_attn.norm_q.weight", + "text_embedding.2.weight", + "head.modulation", + ) + ] + keys_hiprec = [ + ".modulation" # nn.parameter, can't load from BF16 ver + ] + + +class ModelLTXV(ModelTemplate): + arch = "ltxv" + keys_detect = [ + ( + "adaln_single.emb.timestep_embedder.linear_2.weight", + "transformer_blocks.27.scale_shift_table", + "caption_projection.linear_2.weight", + ) + ] + keys_hiprec = [ + "scale_shift_table" # nn.parameter, can't load from BF16 base quant + ] + + +class ModelSDXL(ModelTemplate): + arch = "sdxl" + shape_fix = True + keys_detect = [ + ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",), + ( + "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", + "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight", + ), # Non-diffusers + ("label_emb.0.0.weight",), + ] + + +class ModelSD1(ModelTemplate): + arch = "sd1" + shape_fix = True + keys_detect = [ + ("down_blocks.0.downsamplers.0.conv.weight",), + ( + "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight", + "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight" + ), # Non-diffusers + ] + + +# The architectures are checked in order and the first successful match terminates the search. 
+arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1] + + +def is_model_arch(model, state_dict): + # check if model is correct + matched = False + invalid = False + for match_list in model.keys_detect: + if all(key in state_dict for key in match_list): + matched = True + invalid = any(key in state_dict for key in model.keys_banned) + break + assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)" + return matched + + +def detect_arch(state_dict): + model_arch = None + for arch in arch_list: + if is_model_arch(arch, state_dict): + model_arch = arch() + break + assert model_arch is not None, "Unknown model architecture!" + return model_arch + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET") + parser.add_argument("--src", required=True, help="Source model ckpt file.") + parser.add_argument("--dst", help="Output unet gguf file.") + args = parser.parse_args() + + if not os.path.isfile(args.src): + parser.error("No input provided!") + + return args + + +def strip_prefix(state_dict): + # prefix for mixed state dict + prefix = None + for pfx in ["model.diffusion_model.", "model."]: + if any([x.startswith(pfx) for x in state_dict.keys()]): + prefix = pfx + break + + # prefix for uniform state dict + if prefix is None: + for pfx in ["net."]: + if all([x.startswith(pfx) for x in state_dict.keys()]): + prefix = pfx + break + + # strip prefix if found + if prefix is not None: + logger.info(f"State dict prefix found: '{prefix}'") + sd = {} + for k, v in state_dict.items(): + if prefix not in k: + continue + k = k.replace(prefix, "") + sd[k] = v + else: + logger.debug("State dict has no prefix") + sd = state_dict + + return sd + + +def load_state_dict(path): + if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]): + state_dict = torch.load(path, map_location="cpu", weights_only=True) + for subkey in ["model", "module"]: + if subkey in state_dict: + state_dict = state_dict[subkey] + break + if len(state_dict) < 20: + raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}") + else: + state_dict = load_file(path) + + return strip_prefix(state_dict) + + +def handle_tensors(writer, state_dict, model_arch): + name_lengths = tuple(sorted( + ((key, len(key)) for key in state_dict.keys()), + key=lambda item: item[1], + reverse=True, + )) + if not name_lengths: + return + max_name_len = name_lengths[0][1] + if max_name_len > MAX_TENSOR_NAME_LENGTH: + bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH) + raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. 
Tensors exceeding the limit: {bad_list}") + for key, data in tqdm(state_dict.items()): + old_dtype = data.dtype + + if any(x in key for x in model_arch.keys_ignore): + tqdm.write(f"Filtering ignored key: '{key}'") + continue + + if data.dtype == torch.bfloat16: + data = data.to(torch.float32).numpy() + # this is so we don't break torch 2.0.X + elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]: + data = data.to(torch.float16).numpy() + else: + data = data.numpy() + + n_dims = len(data.shape) + data_shape = data.shape + if old_dtype == torch.bfloat16: + data_qtype = gguf.GGMLQuantizationType.BF16 + # elif old_dtype == torch.float32: + # data_qtype = gguf.GGMLQuantizationType.F32 + else: + data_qtype = gguf.GGMLQuantizationType.F16 + + # The max no. of dimensions that can be handled by the quantization code is 4 + if len(data.shape) > MAX_TENSOR_DIMS: + model_arch.handle_nd_tensor(key, data) + continue # needs to be added back later + + # get number of parameters (AKA elements) in this tensor + n_params = 1 + for dim_size in data_shape: + n_params *= dim_size + + if old_dtype in (torch.float32, torch.bfloat16): + if n_dims == 1: + # one-dimensional tensors should be kept in F32 + # also speeds up inference due to not dequantizing + data_qtype = gguf.GGMLQuantizationType.F32 + + elif n_params <= QUANTIZATION_THRESHOLD: + # very small tensors + data_qtype = gguf.GGMLQuantizationType.F32 + + elif any(x in key for x in model_arch.keys_hiprec): + # tensors that require max precision + data_qtype = gguf.GGMLQuantizationType.F32 + + if (model_arch.shape_fix # NEVER reshape for models such as flux + and n_dims > 1 # Skip one-dimensional tensors + and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement + and (n_params / 256).is_integer() # Rearranging only makes sense if total elements is divisible by 256 + and not (data.shape[-1] / 256).is_integer() # Only need to rearrange if the last dimension is not divisible by 256 + ): + orig_shape = data.shape + data = data.reshape(n_params // 256, 256) + writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape)) + + try: + data = gguf.quants.quantize(data, data_qtype) + except (AttributeError, gguf.QuantError) as e: + tqdm.write(f"falling back to F16: {e}") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + new_name = key # do we need to rename? 
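+        # NOTE: gguf/ggml store dimensions in reverse order relative to torch,
+        # hence the reversed() over data.shape when formatting the log line below.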
+ + shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" + tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + writer.add_tensor(new_name, data, raw_dtype=data_qtype) + + +def convert_file(path, dst_path=None, interact=True, overwrite=False): + # load & run model detection logic + state_dict = load_state_dict(path) + model_arch = detect_arch(state_dict) + logger.info(f"* Architecture detected from input: {model_arch.arch}") + + # detect & set dtype for output file + dtypes = [x.dtype for x in state_dict.values()] + dtypes = {x: dtypes.count(x) for x in set(dtypes)} + main_dtype = max(dtypes, key=dtypes.get) + + if main_dtype == torch.bfloat16: + ftype_name = "BF16" + ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16 + # elif main_dtype == torch.float32: + # ftype_name = "F32" + # ftype_gguf = None + else: + ftype_name = "F16" + ftype_gguf = gguf.LlamaFileType.MOSTLY_F16 + + if dst_path is None: + dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf" + elif "{ftype}" in dst_path: # lcpp logic + dst_path = dst_path.replace("{ftype}", ftype_name) + + if os.path.isfile(dst_path) and not overwrite: + if interact: + input("Output exists enter to continue or ctrl+c to abort!") + else: + raise OSError("Output exists and overwriting is disabled!") + + # handle actual file + writer = gguf.GGUFWriter(path=None, arch=model_arch.arch) + writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + if ftype_gguf is not None: + writer.add_file_type(ftype_gguf) + + handle_tensors(writer, state_dict, model_arch) + writer.write_header_to_file(path=dst_path) + writer.write_kv_data_to_file() + writer.write_tensors_to_file(progress=True) + writer.close() + + fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors" + if os.path.isfile(fix): + logger.warning(f"\n### Warning! 
Fix file found at '{fix}'") + logger.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.") + + return dst_path, model_arch + + +def is_torch_compatible(tensor): + return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES + + +def is_quantized(tensor): + return not is_torch_compatible(tensor) + + +def dequantize_tensor(tensor, dtype=None, dequant_dtype=None): + qtype = getattr(tensor, "tensor_type", None) + oshape = getattr(tensor, "tensor_shape", tensor.shape) + + if qtype in TORCH_COMPATIBLE_QTYPES: + return tensor.to(dtype) + elif qtype in dequantize_functions: + dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype + return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype) + else: + # this is incredibly slow + tqdm.write(f"Falling back to numpy dequant for qtype: {qtype}") + new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype) + return torch.from_numpy(new).to(tensor.device, dtype=dtype) + + +def dequantize(data, qtype, oshape, dtype=None): + """ + Dequantize tensor back to usable shape/dtype + """ + block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] + dequantize_blocks = dequantize_functions[qtype] + + rows = data.reshape( + (-1, data.shape[-1]) + ).view(torch.uint8) + + n_blocks = rows.numel() // type_size + blocks = rows.reshape((n_blocks, type_size)) + blocks = dequantize_blocks(blocks, block_size, type_size, dtype) + return blocks.reshape(oshape) + + +def to_uint32(x): + # no uint32 :( + x = x.view(torch.uint8).to(torch.int32) + return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1) + + +def split_block_dims(blocks, *args): + n_max = blocks.shape[1] + dims = list(args) + [n_max - sum(args)] + return torch.split(blocks, dims, dim=1) + + +# Full weights # +def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None): + return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32) + + +# Legacy Quants # +def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None): + d, x = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + x = x.view(torch.int8) + return (d * x) + + +def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qh, qs = split_block_dims(blocks, 2, 2, 4) + d = d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape((n_blocks, -1)) + + qs = (ql | (qh << 4)) + return (d * qs) + m + + +def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qh, qs = split_block_dims(blocks, 2, 4) + d = d.view(torch.float16).to(dtype) + qh = to_uint32(qh) + + qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) + ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + + qh = (qh & 1).to(torch.uint8) + ql = (ql & 0x0F).reshape(n_blocks, -1) + + qs = (ql | (qh << 4)).to(torch.int8) - 16 + return (d * qs) + + +def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, m, qs = split_block_dims(blocks, 2, 2) + d = 
d.view(torch.float16).to(dtype) + m = m.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) + qs = (qs & 0x0F).reshape(n_blocks, -1) + + return (d * qs) + m + + +def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, qs = split_block_dims(blocks, 2) + d = d.view(torch.float16).to(dtype) + + qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8 + return (d * qs) + + +# K Quants # +QK_K = 256 +K_SCALE_SIZE = 12 + + +def get_scale_min(scales): + n_blocks = scales.shape[0] + scales = scales.view(torch.uint8) + scales = scales.reshape((n_blocks, 3, 4)) + + d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2) + + sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1) + min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1) + + return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8))) + + +def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + ql, qh, scales, d, = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16) + + scales = scales.view(torch.int8).to(dtype) + d = d.view(torch.float16).to(dtype) + d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) + + ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + qh = (qh & 0x03).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)).to(torch.int8) - 32 + q = q.reshape((n_blocks, QK_K // 16, -1)) + + return (d * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8) + + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1)) + ql = (ql & 0x0F).reshape((n_blocks, -1, 32)) + qh = (qh & 0x01).reshape((n_blocks, -1, 32)) + q = (ql | (qh << 4)) + + return (d * q - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + sc, m = get_scale_min(scales) + + d = (d * sc).reshape((n_blocks, -1, 1)) + dm = (dmin * m).reshape((n_blocks, -1, 1)) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1)) + qs = (qs & 0x0F).reshape((n_blocks, -1, 32)) + + return (d * qs - dm).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 
4, 12) + d = d.view(torch.float16).to(dtype) + + lscales, hscales = scales[:, :8], scales[:, 8:] + lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1)) + lscales = lscales.reshape((n_blocks, 16)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 16)) + scales = (lscales & 0x0F) | ((hscales & 0x03) << 4) + scales = (scales.to(torch.int8) - 32) + + dl = (d * scales).reshape((n_blocks, 16, 1)) + + ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1)) + ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3 + qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1 + q = (ql.to(torch.int8) - (qh << 2).to(torch.int8)) + + return (dl * q).reshape((n_blocks, QK_K)) + + +def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None): + n_blocks = blocks.shape[0] + + scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2) + d = d.view(torch.float16).to(dtype) + dmin = dmin.view(torch.float16).to(dtype) + + # (n_blocks, 16, 1) + dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1)) + + shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1)) + + qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3 + qs = qs.reshape((n_blocks, QK_K // 16, 16)) + qs = dl * qs - ml + + return qs.reshape((n_blocks, -1)) + + +dequantize_functions = { + gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16, + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0, + gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1, + gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0, + gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1, + gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K, + gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K, + gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K, +} + + +# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) + + +def get_orig_shape(reader, tensor_name): + field_key = f"comfy.gguf.orig_shape.{tensor_name}" + field = reader.get_field(field_key) + if field is None: + return None + # Has original shape metadata, so we try to decode it. 
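+    # The writer stored this via add_array() as an ARRAY of INT32 (one entry
+    # per dimension), so anything else is treated as malformed metadata.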
+ if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32: + raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}") + return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data)) + + +def get_field(reader, field_name, field_type): + field = reader.get_field(field_name) + if field is None: + return None + elif field_type == str: + # extra check here as this is used for checking arch string + if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING: + raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}") + return str(field.parts[field.data[-1]], encoding="utf-8") + elif field_type in [int, float, bool]: + return field_type(field.parts[field.data[-1]]) + else: + raise TypeError(f"Unknown field type {field_type}") + + +def get_list_field(reader, field_name, field_type): + field = reader.get_field(field_name) + if field is None: + return None + elif field_type == str: + return tuple(str(field.parts[part_idx], encoding="utf-8") for part_idx in field.data) + elif field_type in [int, float, bool]: + return tuple(field_type(field.parts[part_idx][0]) for part_idx in field.data) + else: + raise TypeError(f"Unknown field type {field_type}") + + +def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False) -> dict: + """ + Read state dict as fake tensors + """ + reader = gguf.GGUFReader(path) + + # filter and strip prefix + has_prefix = False + if handle_prefix is not None: + prefix_len = len(handle_prefix) + tensor_names = set(tensor.name for tensor in reader.tensors) + has_prefix = any(s.startswith(handle_prefix) for s in tensor_names) + + tensors = [] + for tensor in reader.tensors: + sd_key = tensor_name = tensor.name + if has_prefix: + if not tensor_name.startswith(handle_prefix): + continue + sd_key = tensor_name[prefix_len:] + tensors.append((sd_key, tensor)) + + # detect and verify architecture + compat = None + arch_str = get_field(reader, "general.architecture", str) + if arch_str in [None, "pig"]: + if is_text_model: + raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})") + compat = "sd.cpp" if arch_str is None else arch_str + # import here to avoid changes to convert.py breaking regular models + try: + arch_str = detect_arch(set(val[0] for val in tensors)).arch + except Exception as e: + raise ValueError(f"This model is not currently supported - ({e})") + elif arch_str not in TXT_ARCH_LIST and is_text_model: + raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}") + elif arch_str not in IMG_ARCH_LIST and not is_text_model: + raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}") + + if compat: + logger.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]") + + # main loading loop + state_dict = {} + qtype_dict = {} + for sd_key, tensor in tensors: + tensor_name = tensor.name + # torch_tensor = torch.from_numpy(tensor.data) # mmap + + # NOTE: line above replaced with this block to avoid persistent numpy warning about mmap + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="The given NumPy array is not writable") + torch_tensor = torch.from_numpy(tensor.data) # mmap + + shape = get_orig_shape(reader, tensor_name) + if shape is None: + shape = torch.Size(tuple(int(v) for v 
in reversed(tensor.shape)))
+            # Workaround for stable-diffusion.cpp SDXL detection.
+            if compat == "sd.cpp" and arch_str == "sdxl":
+                if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]):
+                    while len(shape) > 2 and shape[-1] == 1:
+                        shape = shape[:-1]
+
+        # add to state dict
+        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
+            torch_tensor = torch_tensor.view(*shape)
+        state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
+
+        # keep track of loaded tensor types
+        tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
+        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
+
+    # print loaded tensor type counts
+    logger.info("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items()))
+
+    # mark largest tensor for vram estimation
+    qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
+    if len(qsd) > 0:
+        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
+        state_dict[max_key].is_largest_weight = True
+
+    if return_arch:
+        return (state_dict, arch_str)
+    return state_dict
+
+
+# for remapping llama.cpp -> original key names
+T5_SD_MAP = {
+    "enc.": "encoder.",
+    ".blk.": ".block.",
+    "token_embd": "shared",
+    "output_norm": "final_layer_norm",
+    "attn_q": "layer.0.SelfAttention.q",
+    "attn_k": "layer.0.SelfAttention.k",
+    "attn_v": "layer.0.SelfAttention.v",
+    "attn_o": "layer.0.SelfAttention.o",
+    "attn_norm": "layer.0.layer_norm",
+    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
+    "ffn_up": "layer.1.DenseReluDense.wi_1",
+    "ffn_down": "layer.1.DenseReluDense.wo",
+    "ffn_gate": "layer.1.DenseReluDense.wi_0",
+    "ffn_norm": "layer.1.layer_norm",
+}
+
+LLAMA_SD_MAP = {
+    "blk.": "model.layers.",
+    "attn_norm": "input_layernorm",
+    "attn_q": "self_attn.q_proj",
+    "attn_k": "self_attn.k_proj",
+    "attn_v": "self_attn.v_proj",
+    "attn_output": "self_attn.o_proj",
+    "ffn_up": "mlp.up_proj",
+    "ffn_down": "mlp.down_proj",
+    "ffn_gate": "mlp.gate_proj",
+    "ffn_norm": "post_attention_layernorm",
+    "token_embd": "model.embed_tokens",
+    "output_norm": "model.norm",
+    "output.weight": "lm_head.weight",
+}
+
+
+def sd_map_replace(raw_sd, key_map):
+    sd = {}
+    for k, v in raw_sd.items():
+        for s, d in key_map.items():
+            k = k.replace(s, d)
+        sd[k] = v
+    return sd
+
+
+def llama_permute(raw_sd, n_head, n_head_kv):
+    # Reverse version of LlamaModel.permute in llama.cpp convert script
+    sd = {}
+    permute = lambda x, h: x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)
+    for k, v in raw_sd.items():
+        if k.endswith(("q_proj.weight", "q_proj.bias")):
+            v.data = permute(v.data, n_head)
+        if k.endswith(("k_proj.weight", "k_proj.bias")):
+            v.data = permute(v.data, n_head_kv)
+        sd[k] = v
+    return sd
+
+
+def gguf_tokenizer_loader(path, temb_shape):
+    # convert gguf tokenizer to spiece
+    logger.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
+    spm = model.ModelProto()
+
+    reader = gguf.GGUFReader(path)
+
+    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
+        if temb_shape == (256384, 4096): # probably UMT5
+            spm.trainer_spec.model_type = 1 # Unigram (do we have a T5 w/ BPE?)
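+            # (sentencepiece TrainerSpec.ModelType enum: UNIGRAM = 1, BPE = 2)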
+        else:
+            raise NotImplementedError("Unknown model, can't set tokenizer!")
+    else:
+        raise NotImplementedError("Unknown model, can't set tokenizer!")
+
+    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
+    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)
+
+    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
+    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
+    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)
+
+    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
+        # # These aren't present in the original?
+        # if toktype == 5 and idx >= temb_shape[0] % 1000:
+        #     continue
+
+        piece = spm.SentencePiece()
+        piece.piece = token
+        piece.score = score
+        piece.type = toktype
+        spm.pieces.append(piece)
+
+    # unsure if any of these are correct
+    spm.trainer_spec.byte_fallback = True
+    spm.trainer_spec.vocab_size = len(tokens) # split off unused?
+    spm.trainer_spec.max_sentence_length = 4096
+    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
+    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)
+
+    logger.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
+    del reader
+    return torch.ByteTensor(list(spm.SerializeToString()))
+
+
+def gguf_clip_loader(path):
+    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
+    if arch in {"t5", "t5encoder"}:
+        temb_key = "token_embd.weight"
+        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
+            # non-standard Comfy-Org tokenizer
+            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
+            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
+            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
+            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
+        sd = sd_map_replace(sd, T5_SD_MAP)
+    elif arch in {"llama"}:
+        # TODO: pass model_options["vocab_size"] to loader somehow
+        temb_key = "token_embd.weight"
+        if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
+            # See note above for T5.
+            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
+            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
+        sd = sd_map_replace(sd, LLAMA_SD_MAP)
+        sd = llama_permute(sd, 32, 8) # Llama 3 (32 heads, 8 KV heads)
+    else:
+        pass
+    return sd
+
+
+# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
+
+
+def chained_hasattr(obj, chained_attr):
+    probe = obj
+    for attr in chained_attr.split('.'):
+        if hasattr(probe, attr):
+            probe = getattr(probe, attr)
+        else:
+            return False
+    return True
+
+
+# A backward and forward compatible way to get `torch.compiler.disable`.
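+# Older torch lacks torch.compiler entirely, while 2.8+ (and late nightlies with
+# the nontraceable_tensor_subclasses config) can trace tensor subclasses, so the
+# decorator degrades to a no-op in both cases (see the version checks below).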
+def get_torch_compiler_disable_decorator(): + def dummy_decorator(*args, **kwargs): + def noop(x): + return x + + return noop + + from packaging import version + + if not chained_hasattr(torch, "compiler.disable"): + logger.info("ComfyUI-GGUF: Torch too old for torch.compile - bypassing") + return dummy_decorator # torch too old + elif version.parse(torch.__version__) >= version.parse("2.8"): + logger.info("ComfyUI-GGUF: Allowing full torch compile") + return dummy_decorator # torch compile works + if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"): + logger.info("ComfyUI-GGUF: Allowing full torch compile (nightly)") + return dummy_decorator # torch compile works, nightly before 2.8 release + else: + logger.info("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch") + return torch.compiler.disable + + +torch_compiler_disable = get_torch_compiler_disable_decorator() + + +class GGMLTensor(torch.Tensor): + """ + Main tensor-like class for storing quantized weights + """ + + def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs): + super().__init__() + self.tensor_type = tensor_type + self.tensor_shape = tensor_shape + self.patches = patches + + def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs): + return super().__new__(cls, *args, **kwargs) + + def to(self, *args, **kwargs): + new = super().to(*args, **kwargs) + new.tensor_type = getattr(self, "tensor_type", None) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + new.patches = getattr(self, "patches", []).copy() + return new + + def clone(self, *args, **kwargs): + return self + + def detach(self, *args, **kwargs): + return self + + def copy_(self, *args, **kwargs): + # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel + try: + return super().copy_(*args, **kwargs) + except Exception as e: + logger.warning(f"ignoring 'copy_' on tensor: {e}") + + def new_empty(self, size, *args, **kwargs): + # Intel Arc fix, ref#50 + new_tensor = super().new_empty(size, *args, **kwargs) + return GGMLTensor( + new_tensor, + tensor_type=getattr(self, "tensor_type", None), + tensor_shape=size, + patches=getattr(self, "patches", []).copy() + ) + + @property + def shape(self): + if not hasattr(self, "tensor_shape"): + self.tensor_shape = self.size() + return self.tensor_shape + + +class GGMLLayer(torch.nn.Module): + """ + This (should) be responsible for de-quantizing on the fly + """ + comfy_cast_weights = True + dequant_dtype = None + patch_dtype = None + largest_layer = False + torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16} + + def is_ggml_quantized(self, *, weight=None, bias=None): + if weight is None: + weight = self.weight + if bias is None: + bias = self.bias + return is_quantized(weight) or is_quantized(bias) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias") + # NOTE: using modified load for linear due to not initializing on creation, see GGMLOps todo + if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear): + return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs) + # Not strictly required, but fixes embedding shape mismatch. 
Threshold set in loader.py + if isinstance(self, torch.nn.Embedding) and self.weight.shape[0] >= (64 * 1024): + return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs) + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + prefix_len = len(prefix) + for k, v in state_dict.items(): + if k[prefix_len:] == "weight": + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + # For Linear layer with missing weight + if self.weight is None and isinstance(self, torch.nn.Linear): + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix + "weight") + + # for vram estimation (TODO: less fragile logic?) + if getattr(self.weight, "is_largest_weight", False): + self.largest_layer = True + + def _save_to_state_dict(self, *args, **kwargs): + if self.is_ggml_quantized(): + return self.ggml_save_to_state_dict(*args, **kwargs) + return super()._save_to_state_dict(*args, **kwargs) + + def ggml_save_to_state_dict(self, destination, prefix, keep_vars): + # This is a fake state dict for vram estimation + weight = torch.zeros_like(self.weight, device=torch.device("meta")) + destination[prefix + "weight"] = weight + if self.bias is not None: + bias = torch.zeros_like(self.bias, device=torch.device("meta")) + destination[prefix + "bias"] = bias + + # Take into account space required for dequantizing the largest tensor + if self.largest_layer: + shape = getattr(self.weight, "tensor_shape", self.weight.shape) + dtype = self.dequant_dtype or torch.float16 + temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype) + destination[prefix + "temp.weight"] = temp + + return + # This would return the dequantized state dict + destination[prefix + "weight"] = self.get_weight(self.weight) + if bias is not None: + destination[prefix + "bias"] = self.get_weight(self.bias) + + def get_weight(self, tensor, dtype): + if tensor is None: + return + + # consolidate and load patches to GPU in async + patch_list = [] + device = tensor.device + for patches, key in getattr(tensor, "patches", []): + patch_list += move_patch_to_device(patches, device) + + # dequantize tensor while patches load + weight = dequantize_tensor(tensor, dtype, self.dequant_dtype) + + # prevent propagating custom tensor class + if isinstance(weight, GGMLTensor): + weight = torch.Tensor(weight) + + # apply patches + if len(patch_list) > 0: + if self.patch_dtype is None: + weight = calculate_weight(patch_list, weight, key) + else: + # for testing, may degrade image quality + patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype + weight = calculate_weight(patch_list, weight, key, patch_dtype) + return weight + + @torch_compiler_disable() + def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): + if input is not None: + if dtype is None: + dtype = getattr(input, "dtype", torch.float32) + if bias_dtype is None: + bias_dtype = dtype + if device is None: + device = input.device + + bias = None + non_blocking = device_supports_non_blocking(device) + if s.bias is not None: + bias = s.get_weight(s.bias.to(device), dtype) + bias = cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False) + + weight = 
s.get_weight(s.weight.to(device), dtype) + weight = cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False) + return weight, bias + + def forward_comfy_cast_weights(self, input, *args, **kwargs): + if self.is_ggml_quantized(): + out = self.forward_ggml_cast_weights(input, *args, **kwargs) + else: + out = super().forward_comfy_cast_weights(input, *args, **kwargs) + + # non-ggml forward might still propagate custom tensor class + if isinstance(out, GGMLTensor): + out = torch.Tensor(out) + return out + + def forward_ggml_cast_weights(self, input): + raise NotImplementedError + + +class GGMLOps(manual_cast): + """ + Dequantize weights on the fly before doing the compute + """ + + class Linear(GGMLLayer, manual_cast.Linear): + dequant_dtype = None + patch_dtype = None + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + torch.nn.Module.__init__(self) + # TODO: better workaround for reserved memory spike on windows + # Issue is with `torch.empty` still reserving the full memory for the layer + # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + + def forward_ggml_cast_weights(self, input): + weight, bias = self.cast_bias_weight(input) + return torch.nn.functional.linear(input, weight, bias) + + class Conv2d(GGMLLayer, manual_cast.Conv2d): + def forward_ggml_cast_weights(self, input): + weight, bias = self.cast_bias_weight(input) + return self._conv_forward(input, weight, bias) + + class Embedding(GGMLLayer, manual_cast.Embedding): + def forward_ggml_cast_weights(self, input, out_dtype=None): + output_dtype = out_dtype + if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: + out_dtype = None + weight, _bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype) + return torch.nn.functional.embedding( + input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse + ).to(dtype=output_dtype) + + class LayerNorm(GGMLLayer, manual_cast.LayerNorm): + def forward_ggml_cast_weights(self, input): + if self.weight is None: + return super().forward_comfy_cast_weights(input) + weight, bias = self.cast_bias_weight(input) + return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + + class GroupNorm(GGMLLayer, manual_cast.GroupNorm): + def forward_ggml_cast_weights(self, input): + weight, bias = self.cast_bias_weight(input) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + +def move_patch_to_device(item, device): + if isinstance(item, torch.Tensor): + return item.to(device, non_blocking=True) + elif isinstance(item, tuple): + return tuple(move_patch_to_device(x, device) for x in item) + elif isinstance(item, list): + return [move_patch_to_device(x, device) for x in item] + else: + return item diff --git a/comfy/model_detection.py b/comfy/model_detection.py index da17d425c..c3d5d9388 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,11 +1,13 @@ import json import logging import math +from typing import Optional import torch from . import supported_models, utils from . 
import supported_models_base +from .gguf import GGMLOps def count_blocks(state_dict_keys, prefix_string): @@ -620,7 +622,7 @@ def model_config_from_unet_config(unet_config, state_dict=None): return None -def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None): +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata:Optional[dict]=None): unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata) if unet_config is None: return None @@ -639,6 +641,9 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: model_config.optimizations["fp8"] = True + if metadata is not None and "format" in metadata and metadata["format"] == "gguf": + model_config.custom_operations = GGMLOps + return model_config diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 316cec02e..ca1ed34f4 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -557,6 +557,8 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors"), HuggingFile("Comfy-Org/Wan_2.2_ComfyUI_Repackaged", "split_files/diffusion_models/wan2.2_ti2v_5B_fp16.safetensors"), HuggingFile("lodestones/Chroma", "chroma-unlocked-v37.safetensors"), + HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "HighNoise/Wan2.2-T2V-A14B-HighNoise-Q8_0.gguf"), + HuggingFile("QuantStack/Wan2.2-T2V-A14B-GGUF", "LowNoise/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"), ], folder_names=["diffusion_models", "unet"]) KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4e11eb403..c21c532e4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -19,6 +19,7 @@ from __future__ import annotations import collections import copy +import dataclasses import inspect import logging import typing @@ -37,6 +38,7 @@ from . 
import utils from .comfy_types import UnetWrapperFunction from .component_model.deprecation import _deprecate_method from .float import stochastic_rounding +from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue from .model_base import BaseModel @@ -221,6 +223,13 @@ class MemoryCounter: self.value -= used +@dataclasses.dataclass +class GGUFQuantization: + loaded_from_gguf: bool = False + mmap_released: bool = False + patch_on_device: bool = False + + class ModelPatcher(ModelManageable): def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None): self.size = size @@ -257,6 +266,9 @@ class ModelPatcher(ModelManageable): self.forced_hooks: Optional[HookGroup] = None # NOTE: only used for CLIP at this time self.is_clip = False self.hook_mode = EnumHookMode.MaxSpeed + self.gguf = GGUFQuantization() + if isinstance(model, BaseModel) and model.operations == GGMLOps: + self.gguf.loaded_from_gguf = True @property def model_options(self) -> ModelOptions: @@ -361,6 +373,9 @@ class ModelPatcher(ModelManageable): n.forced_hooks = self.forced_hooks.clone() if self.forced_hooks else self.forced_hooks n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.gguf = copy.copy(self.gguf) + # todo: when is this set back to False? when would it make sense to? + n.gguf.mmap_released = False for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) @@ -612,6 +627,19 @@ class ModelPatcher(ModelManageable): weight, set_func, convert_func = get_key_weight(self.model, key) inplace_update = self.weight_inplace_update or inplace_update + # from gguf + if is_quantized(weight): + out_weight = weight.to(device_to) + patches = move_patch_to_device(self.patches[key], self.load_device if self.patch_on_device else self.offload_device) + # TODO: do we ever have legitimate duplicate patches? (i.e. patch on top of patched weight) + out_weight.patches = [(patches, key)] + if inplace_update: + utils.copy_to_param(self.model, key, out_weight) + else: + utils.set_attr_param(self.model, key, out_weight) + return + # end gguf + if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) @@ -648,6 +676,9 @@ class ModelPatcher(ModelManageable): return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): + if self.gguf.loaded_from_gguf: + force_patch_weights = True + with self.use_ejected(): self.unpatch_hooks() mem_counter = 0 @@ -744,6 +775,28 @@ class ModelPatcher(ModelManageable): self.model.to(device_to) mem_counter = self.model_size() + if self.gguf.loaded_from_gguf and not self.gguf.mmap_released: + # todo: when is mmap_released set to True? 
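+            # Round-tripping offloaded modules through the load device (below)
+            # forces a real copy, detaching their storage from the mmap'd GGUF file.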
+ linked = [] + if lowvram_model_memory > 0: + for n, m in self.model.named_modules(): + if hasattr(m, "weight"): + device = getattr(m.weight, "device", None) + if device == self.offload_device: + linked.append((n, m)) + continue + if hasattr(m, "bias"): + device = getattr(m.bias, "device", None) + if device == self.offload_device: + linked.append((n, m)) + continue + if linked and self.load_device != self.offload_device: + logger.info(f"gguf attempting to release mmap ({len(linked)})") + for n, m in linked: + # TODO: possible to OOM, find better way to detach + m.to(self.load_device).to(self.offload_device) + self.gguf.mmap_released = True + self._memory_measurements.lowvram_patch_counter += patch_counter self.model_device = device_to @@ -774,6 +827,13 @@ class ModelPatcher(ModelManageable): def unpatch_model(self, device_to=None, unpatch_weights=True): self.eject_model() + if self.gguf.loaded_from_gguf and unpatch_weights: + for p in self.model.parameters(): + if is_torch_compatible(p): + continue + patches = self.patches + if len(patches) > 0: + p.patches = [] if unpatch_weights: self.unpatch_hooks() if self._memory_measurements.model_lowvram: diff --git a/comfy/ops.py b/comfy/ops.py index 1dd215ae9..45cc30cfa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -503,7 +503,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None) ): - logging.info("Using cublas ops") + logger.info("Using cublas ops") return cublas_ops if compute_dtype is None or weight_dtype == compute_dtype: diff --git a/comfy/sd.py b/comfy/sd.py index 74569c501..afbb102df 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -23,6 +23,7 @@ from . import model_sampling from . import sd1_clip from . import sdxl_clip from . 
import utils +from .component_model.deprecation import _deprecate_method from .hooks import EnumHookMode from .ldm.ace.vae.music_dcae_pipeline import MusicDCAE from .ldm.audio.autoencoder import AudioOobleckVAE @@ -58,7 +59,7 @@ from .text_encoders import sa_t5 from .text_encoders import sd2_clip from .text_encoders import sd3_clip from .text_encoders import wan -from .utils import ProgressBar +from .utils import ProgressBar, FileMetadata logger = logging.getLogger(__name__) @@ -1098,7 +1099,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[str | dict] = None, ckpt_path=""): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, metadata: Optional[FileMetadata] = None, ckpt_path=""): if te_model_options is None: te_model_options = {} if model_options is None: @@ -1182,7 +1183,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (_model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = ""): # load unet in diffusers or regular format +def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str] = "", metadata: Optional[FileMetadata] = None): # load unet in diffusers or regular format """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1192,6 +1193,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O - dtype: Override model data type - custom_operations: Custom model operations - fp8_optimizations: Enable FP8 optimizations + metadata: file metadata Returns: ModelPatcher: A wrapped model instance that handles device management and weight loading. 
@@ -1217,14 +1219,14 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O parameters = utils.calculate_parameters(sd) weight_dtype = utils.weight_dtype(sd) load_device = model_management.get_torch_device() - model_config = model_detection.model_config_from_unet(sd, "") + model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: new_sd = sd else: new_sd = model_detection.convert_diffusers_mmdit(sd, "") if new_sd is not None: # diffusers mmdit - model_config = model_detection.model_config_from_unet(new_sd, "") + model_config = model_detection.model_config_from_unet(new_sd, "", metadata=metadata) if model_config is None: return None else: # diffusers unet @@ -1269,21 +1271,21 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O def load_diffusion_model(unet_path, model_options: dict = None): if model_options is None: model_options = {} - sd = utils.load_torch_file(unet_path) - model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path) + sd, metadata = utils.load_torch_file(unet_path, return_metadata=True) + model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path, metadata=metadata) if model is None: logger.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) return model +@_deprecate_method(message="The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model", version="*") def load_unet(unet_path, dtype=None): - logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) +@_deprecate_method(message="The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict", version="*") def load_unet_state_dict(sd, dtype=None): - logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict") return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype}) diff --git a/comfy/utils.py b/comfy/utils.py index 92dcf71e4..ce272b2d3 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -31,7 +31,7 @@ import warnings from contextlib import contextmanager from pathlib import Path from pickle import UnpicklingError -from typing import Optional, Any +from typing import Optional, Any, Literal import numpy as np import safetensors.torch @@ -40,15 +40,15 @@ from PIL import Image from einops import rearrange from torch.nn.functional import interpolate from tqdm import tqdm +from typing_extensions import TypedDict, NotRequired -from comfy_api import feature_flags from . 
import interruption, checkpoint_pickle from .cli_args import args from .component_model import files from .component_model.deprecation import _deprecate_method from .component_model.executor_types import ExecutorToClientProgress, ProgressMessage -from .component_model.queue_types import BinaryEventTypes from .execution_context import current_execution_context +from .gguf import gguf_sd_loader from .progress import get_progress_state MMAP_TORCH_FILES = args.mmap_torch_files @@ -89,12 +89,17 @@ def _get_progress_bar_enabled(): setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled)) -def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False): +class FileMetadata(TypedDict): + format: NotRequired[Literal["gguf"]] + + +def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=False) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], Optional[FileMetadata]]: if device is None: device = torch.device("cpu") if ckpt is None: raise FileNotFoundError("the checkpoint was not found") - metadata = None + metadata: Optional[dict[str, str]] = None + sd: dict[str, torch.Tensor] = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: with safetensors.safe_open(Path(ckpt).resolve(strict=True), framework="pt", device=device.type) as f: @@ -128,6 +133,10 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal sd: dict[str, torch.Tensor] = {} for checkpoint_file in checkpoint_files: sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type)) + elif ckpt.lower().endswith(".gguf"): + # from gguf + sd = gguf_sd_loader(ckpt) + metadata = {"format": "gguf"} else: try: torch_args = {} @@ -138,7 +147,7 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) else: logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt)) - pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) + pl_sd = torch.load(ckpt, map_location=device, pickle_module=checkpoint_pickle) if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: @@ -153,14 +162,14 @@ def load_torch_file(ckpt: str, safe_load=False, device=None, return_metadata=Fal try: # wrong extension is most likely, try to load as safetensors anyway sd = safetensors.torch.load_file(Path(ckpt).resolve(strict=True), device=device.type) - return sd except Exception: msg = f"The checkpoint at {ckpt} could not be loaded as a safetensor nor a torch checkpoint. The file at the path is corrupted or unexpected. Try deleting it and downloading it again" if hasattr(exc_info, "add_note"): exc_info.add_note(msg) else: logger.error(msg, exc_info=exc_info) - raise exc_info + if sd is None: + raise exc_info return (sd, metadata) if return_metadata else sd @@ -1125,6 +1134,7 @@ def _progress_bar_update(value: float, total: float, preview_image_or_data: Opti get_progress_state().update_progress(node_id, value, total, preview_image_or_data) server.send_sync("progress", progress, client_id) + def set_progress_bar_enabled(enabled: bool): warnings.warn( "The global method 'set_progress_bar_enabled' is deprecated and will be removed in a future version. 
Use current_execution_context().server.receive_all_progress_notifications instead.", diff --git a/pyproject.toml b/pyproject.toml index e610094e4..b7f85d4fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ dependencies = [ "setuptools", "alembic", "SQLAlchemy", + "gguf", ] [build-system]
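
For reviewers, a minimal sketch of how the new GGUF path is exercised end to end
(assuming a locally downloaded Q8_0 file; the path below is illustrative):

    from comfy.sd import load_diffusion_model

    # load_torch_file now routes ".gguf" through gguf_sd_loader and returns
    # metadata={"format": "gguf"}, which makes model_detection select GGMLOps,
    # so quantized weights are dequantized on the fly at inference time.
    patcher = load_diffusion_model(
        "models/diffusion_models/Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"
    )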