"""
Copyright 2025 "City96" and Benjanin Berman

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import logging
import os
import warnings
from collections import Counter

import gguf
import torch
from safetensors.torch import load_file, save_file
from sentencepiece import sentencepiece_model_pb2 as model
from tqdm import tqdm

from .lora import calculate_weight
from .model_management import device_supports_non_blocking
from .ops import cast_to, manual_cast

logger = logging.getLogger(__name__)

# Tensors with at most this many elements are stored as F32 (see handle_tensors).
QUANTIZATION_THRESHOLD = 1024
# Minimum element count for the (n, 256) rearrangement applied to shape_fix models.
REARRANGE_THRESHOLD = 512
# Longest tensor name handle_tensors accepts.
MAX_TENSOR_NAME_LENGTH = 127
# Tensors with more dimensions than this are handed to ModelTemplate.handle_nd_tensor.
MAX_TENSOR_DIMS = 4
# Tensor types that are handled as plain torch tensors (no block dequantization).
TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}


class ModelTemplate:
    """
    Base description of a convertible architecture; subclasses override the class attributes.
    """
    arch = "invalid"  # string describing architecture
    shape_fix = False  # whether to reshape tensors
    keys_detect = []  # list of lists to match in state dict
    keys_banned = []  # list of keys that should mark model as invalid for conversion
    keys_hiprec = []  # list of keys that need to be kept in fp32 for some reason
    keys_ignore = []  # list of strings to ignore keys by when found

    def handle_nd_tensor(self, key, data):
        """
        Called for tensors with more than MAX_TENSOR_DIMS dimensions; unsupported by default.
        """
        raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")


class ModelFlux(ModelTemplate):
    arch = "flux"
    keys_detect = [
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight", ]


class ModelSD3(ModelTemplate):
    arch = "sd3"
    keys_detect = [
        ("transformer_blocks.0.attn.add_q_proj.weight",),
        ("joint_blocks.0.x_block.attn.qkv.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight", ]


class ModelAura(ModelTemplate):
    arch = "aura"
    keys_detect = [
        ("double_layers.3.modX.1.weight",),
        ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
    ]
    keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight", ]


class ModelHiDream(ModelTemplate):
    arch = "hidream"
    keys_detect = [
        (
            "caption_projection.0.linear.weight",
            "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight"
        )
    ]
    keys_hiprec = [
        # nn.parameter, can't load from BF16 ver
        ".ff_i.gate.weight",
        "img_emb.emb_pos"
    ]


class CosmosPredict2(ModelTemplate):
    arch = "cosmos"
    keys_detect = [
        (
            "blocks.0.mlp.layer1.weight",
            "blocks.0.adaln_modulation_cross_attn.1.weight",
        )
    ]
    keys_hiprec = ["pos_embedder"]
    keys_ignore = ["_extra_state", "accum_"]


class ModelHyVid(ModelTemplate):
    arch = "hyvid"
    keys_detect = [
        (
            "double_blocks.0.img_attn_proj.weight",
            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
        )
    ]

    def handle_nd_tensor(self, key, data):
        """
        Save the over-dimensioned tensor to a side file so it can be re-added manually later.

        Raises RuntimeError if the side file already exists.
        """
        # hacky but don't have any better ideas
        path = f"./fix_5d_tensors_{self.arch}.safetensors"  # TODO: somehow get a path here??
        if os.path.isfile(path):
            raise RuntimeError(f"5D tensor fix file already exists! {path}")
        fsd = {key: torch.from_numpy(data)}
        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
        save_file(fsd, path)


class ModelWan(ModelHyVid):
    arch = "wan"
    keys_detect = [
        (
            "blocks.0.self_attn.norm_q.weight",
            "text_embedding.2.weight",
            "head.modulation",
        )
    ]
    keys_hiprec = [
        ".modulation"  # nn.parameter, can't load from BF16 ver
    ]


class ModelLTXV(ModelTemplate):
    arch = "ltxv"
    keys_detect = [
        (
            "adaln_single.emb.timestep_embedder.linear_2.weight",
            "transformer_blocks.27.scale_shift_table",
            "caption_projection.linear_2.weight",
        )
    ]
    keys_hiprec = [
        "scale_shift_table"  # nn.parameter, can't load from BF16 base quant
    ]


class ModelSDXL(ModelTemplate):
    arch = "sdxl"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
        ),  # Non-diffusers
        ("label_emb.0.0.weight",),
    ]


class ModelSD1(ModelTemplate):
    arch = "sd1"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight"
        ),  # Non-diffusers
    ]


# The architectures are checked in order and the first successful match terminates the search.
arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]


def is_model_arch(model, state_dict):
    """
    Return True when every key of one of `model.keys_detect`'s tuples is in `state_dict`.

    `state_dict` only needs to support `in` (a dict or a set of key names).
    Raises AssertionError when the matching state dict also contains a banned key.
    """
    matched = False
    invalid = False
    for match_list in model.keys_detect:
        if all(key in state_dict for key in match_list):
            matched = True
            invalid = any(key in state_dict for key in model.keys_banned)
            break
    # Raised explicitly (same type/message as the former `assert`) so the check survives `python -O`.
    if invalid:
        raise AssertionError("Model architecture not allowed for conversion! (i.e. reference VS diffusers format)")
    return matched


def detect_arch(state_dict):
    """
    Return an instance of the first class in `arch_list` that matches `state_dict`.

    Raises AssertionError when nothing matches (or a banned format is detected).
    """
    model_arch = None
    for arch in arch_list:
        if is_model_arch(arch, state_dict):
            model_arch = arch()
            break
    # Explicit raise instead of `assert` so the check is not stripped under `python -O`.
    if model_arch is None:
        raise AssertionError("Unknown model architecture!")
    return model_arch


def parse_args():
    """
    Parse the command line (`--src` required, `--dst` optional); exits if `--src` is not a file.
    """
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
    parser.add_argument("--src", required=True, help="Source model ckpt file.")
    parser.add_argument("--dst", help="Output unet gguf file.")
    args = parser.parse_args()

    if not os.path.isfile(args.src):
        parser.error("No input provided!")

    return args


def strip_prefix(state_dict):
    """
    Remove a known leading prefix from the state dict keys.

    "model.diffusion_model." / "model." are used when any key starts with them (keys without
    the prefix are dropped); "net." is used only when every key starts with it. Without a
    detected prefix the input dict is returned unchanged.
    """
    # prefix for mixed state dict
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any(x.startswith(pfx) for x in state_dict.keys()):
            prefix = pfx
            break

    # prefix for uniform state dict
    if prefix is None:
        for pfx in ["net."]:
            if all(x.startswith(pfx) for x in state_dict.keys()):
                prefix = pfx
                break

    # strip prefix if found
    if prefix is not None:
        logger.info(f"State dict prefix found: '{prefix}'")
        sd = {}
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            # Match the prefix only at the start, consistent with the detection above;
            # a substring test would keep (and mangle) keys such as "first_stage_model.x".
            if not k.startswith(prefix):
                continue
            sd[k[prefix_len:]] = v
    else:
        logger.debug("State dict has no prefix")
        sd = state_dict

    return sd


def load_state_dict(path):
    """
    Load a checkpoint (torch pickle formats or safetensors) and strip its key prefix.

    Raises RuntimeError when a torch checkpoint yields fewer than 20 entries after unwrapping
    an optional "model"/"module" sub-dict.
    """
    if path.endswith((".ckpt", ".pt", ".bin", ".pth")):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        for subkey in ["model", "module"]:
            if subkey in state_dict:
                state_dict = state_dict[subkey]
                break
        if len(state_dict) < 20:
            raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
    else:
        state_dict = load_file(path)

    return strip_prefix(state_dict)


def handle_tensors(writer, state_dict, model_arch):
    """
    Convert every tensor in `state_dict` and add it to the GGUF `writer`.

    Raises ValueError when a tensor name is longer than MAX_TENSOR_NAME_LENGTH.
    """
    name_lengths = tuple(sorted(
        ((key, len(key)) for key in state_dict.keys()),
        key=lambda item: item[1],
        reverse=True,
    ))
    if not name_lengths:
        return
    max_name_len = name_lengths[0][1]
    if max_name_len > MAX_TENSOR_NAME_LENGTH:
        bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
        raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")

    for key, data in tqdm(state_dict.items()):
        old_dtype = data.dtype

        if any(x in key for x in model_arch.keys_ignore):
            tqdm.write(f"Filtering ignored key: '{key}'")
            continue

        if data.dtype == torch.bfloat16:
            data = data.to(torch.float32).numpy()
        # this is so we don't break torch 2.0.X
        elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
            data = data.to(torch.float16).numpy()
        else:
            data = data.numpy()

        n_dims = len(data.shape)
        data_shape = data.shape
        if old_dtype == torch.bfloat16:
            data_qtype = gguf.GGMLQuantizationType.BF16
        # elif old_dtype == torch.float32:
        #     data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16

        # The max no. of dimensions that can be handled by the quantization code is 4
        if n_dims > MAX_TENSOR_DIMS:
            model_arch.handle_nd_tensor(key, data)
            continue  # needs to be added back later

        # get number of parameters (AKA elements) in this tensor
        n_params = 1
        for dim_size in data_shape:
            n_params *= dim_size

        if old_dtype in (torch.float32, torch.bfloat16):
            if n_dims == 1:
                # one-dimensional tensors should be kept in F32
                # also speeds up inference due to not dequantizing
                data_qtype = gguf.GGMLQuantizationType.F32

            elif n_params <= QUANTIZATION_THRESHOLD:
                # very small tensors
                data_qtype = gguf.GGMLQuantizationType.F32

            elif any(x in key for x in model_arch.keys_hiprec):
                # tensors that require max precision
                data_qtype = gguf.GGMLQuantizationType.F32

        if (
            model_arch.shape_fix  # NEVER reshape for models such as flux
            and n_dims > 1  # Skip one-dimensional tensors
            and n_params >= REARRANGE_THRESHOLD  # Only rearrange tensors meeting the size requirement
            and n_params % 256 == 0  # Rearranging only makes sense if total elements is divisible by 256
            and data.shape[-1] % 256 != 0  # Only need to rearrange if the last dimension is not divisible by 256
        ):
            orig_shape = data.shape
            data = data.reshape(n_params // 256, 256)
            # remember the original shape so the loader can restore it (see get_orig_shape)
            writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))

        try:
            data = gguf.quants.quantize(data, data_qtype)
        except (AttributeError, gguf.QuantError) as e:
            tqdm.write(f"falling back to F16: {e}")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        new_name = key  # do we need to rename?

        shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
        tqdm.write(f"{new_name:<{max_name_len + 4}} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

        writer.add_tensor(new_name, data, raw_dtype=data_qtype)


def convert_file(path, dst_path=None, interact=True, overwrite=False):
    """
    Convert the checkpoint at `path` to a GGUF file and return `(dst_path, model_arch)`.

    `dst_path` defaults to "<path without extension>-<ftype>.gguf"; a literal "{ftype}" in a
    given `dst_path` is replaced by the file type name. When the output exists and `overwrite`
    is False, the user is prompted if `interact` is True, otherwise OSError is raised.
    """
    # load & run model detection logic
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    logger.info(f"* Architecture detected from input: {model_arch.arch}")

    # detect & set dtype for output file (most common tensor dtype wins)
    dtype_counts = Counter(x.dtype for x in state_dict.values())
    main_dtype = max(dtype_counts, key=dtype_counts.get)

    if main_dtype == torch.bfloat16:
        ftype_name = "BF16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
    # elif main_dtype == torch.float32:
    #     ftype_name = "F32"
    #     ftype_gguf = None
    else:
        ftype_name = "F16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16

    if dst_path is None:
        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
    elif "{ftype}" in dst_path:  # lcpp logic
        dst_path = dst_path.replace("{ftype}", ftype_name)

    if os.path.isfile(dst_path) and not overwrite:
        if interact:
            input("Output exists enter to continue or ctrl+c to abort!")
        else:
            raise OSError("Output exists and overwriting is disabled!")

    # handle actual file
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    try:
        writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
        if ftype_gguf is not None:
            writer.add_file_type(ftype_gguf)

        handle_tensors(writer, state_dict, model_arch)
        writer.write_header_to_file(path=dst_path)
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file(progress=True)
    finally:
        # close the writer even when a step above fails
        writer.close()

    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
    if os.path.isfile(fix):
        logger.warning(f"\n### Warning! Fix file found at '{fix}'")
        logger.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")

    return dst_path, model_arch


def is_torch_compatible(tensor):
    """
    True for None and for tensors whose `tensor_type` attribute (if any) is in TORCH_COMPATIBLE_QTYPES.
    """
    return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor):
    """
    Inverse of `is_torch_compatible`.
    """
    return not is_torch_compatible(tensor)


def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
    """
    Return `tensor` converted to `dtype`, dequantizing it first when its `tensor_type` requires it.

    `dequant_dtype` is the dtype passed to the block dequantizer; the string "target" means
    "use `dtype`". Types without a torch implementation fall back to gguf's numpy dequantizer.
    """
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {qtype}")
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(data, qtype, oshape, dtype=None):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    # reinterpret the storage as bytes and cut it into rows of `type_size` bytes (one per block)
    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


def to_uint32(x):
    """
    Combine the four byte columns of `x` (little-endian) into one int32 column of shape (n, 1).
    """
    # no uint32 :(
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks, *args):
    """
    Split `blocks` along dim 1 into chunks of the given widths plus one chunk with the remainder.
    """
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


# Full weights #
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    # widen each 16-bit value into the upper half of a 32-bit word and reinterpret as float32
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


# Legacy Quants #
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    # per block: 2-byte fp16 scale, then int8 values
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return (d * x)


def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    # per block: fp16 scale, fp16 offset, 4 bytes of high bits, packed low nibbles
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = (ql | (qh << 4))
    return (d * qs) + m


def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
    # per block: fp16 scale, 4 bytes of high bits, packed low nibbles; values are offset by 16
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return (d * qs)


def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
    # per block: fp16 scale, fp16 offset, packed nibbles
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
    # per block: fp16 scale, packed nibbles; values are offset by 8
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return (d * qs)


# K Quants #
QK_K = 256
K_SCALE_SIZE = 12


def get_scale_min(scales):
    """
    Unpack the 12 packed scale bytes of each block into 8 scales and 8 mins (6 bits each).
    """
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    mins = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), mins.reshape((n_blocks, 8)))


def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    ql, qh, scales, d, = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4))

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = (scales.to(torch.int8) - 32)

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor(list(range(8)), device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


# Maps a GGML quantization type to the torch routine that expands its blocks.
dequantize_functions = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)


def get_orig_shape(reader, tensor_name):
    """
    Return the original shape stored under "comfy.gguf.orig_shape.<tensor_name>", or None if absent.

    Raises TypeError when the metadata field is not an ARRAY of INT32.
    """
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # Has original shape metadata, so we try to decode it.
    if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32:
        raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")
    return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))


def get_field(reader, field_name, field_type):
    """
    Read a scalar metadata field as `field_type` (str, int, float or bool); None if absent.

    Raises TypeError for a str request on a non-string field or for an unsupported `field_type`.
    """
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type is str:
        # extra check here as this is used for checking arch string
        if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
        return str(field.parts[field.data[-1]], encoding="utf-8")
    elif field_type in [int, float, bool]:
        return field_type(field.parts[field.data[-1]])
    else:
        raise TypeError(f"Unknown field type {field_type}")


def get_list_field(reader, field_name, field_type):
    """
    Read an array metadata field as a tuple of `field_type` values; None if absent.

    Raises TypeError for an unsupported `field_type`.
    """
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type is str:
        return tuple(str(field.parts[part_idx], encoding="utf-8") for part_idx in field.data)
    elif field_type in [int, float, bool]:
        return tuple(field_type(field.parts[part_idx][0]) for part_idx in field.data)
    else:
        raise TypeError(f"Unknown field type {field_type}")


def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False) -> dict:
    """
    Read state dict as fake tensors

    Returns the state dict, or `(state_dict, arch_str)` when `return_arch` is True.
    Raises ValueError for unsupported or unexpected architectures.
    """
    reader = gguf.GGUFReader(path)

    # filter and strip prefix
    has_prefix = False
    if handle_prefix is not None:
        prefix_len = len(handle_prefix)
        tensor_names = set(tensor.name for tensor in reader.tensors)
        has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
    else:
        prefix_len = 0

    tensors = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # detect and verify architecture
    compat = None
    arch_str = get_field(reader, "general.architecture", str)
    if arch_str in [None, "pig"]:
        if is_text_model:
            raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})")
        compat = "sd.cpp" if arch_str is None else arch_str
        # no usable architecture metadata: detect the architecture from the tensor names
        try:
            arch_str = detect_arch(set(val[0] for val in tensors)).arch
        except Exception as e:
            raise ValueError(f"This model is not currently supported - ({e})") from e
    elif arch_str not in TXT_ARCH_LIST and is_text_model:
        raise ValueError(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
    elif arch_str not in IMG_ARCH_LIST and not is_text_model:
        raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")

    if compat:
        logger.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")

    # main loading loop
    state_dict = {}
    qtype_dict = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        # torch_tensor = torch.from_numpy(tensor.data) # mmap
        # NOTE: line above replaced with this block to avoid persistent numpy warning about mmap
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
            torch_tensor = torch.from_numpy(tensor.data)  # mmap

        shape = get_orig_shape(reader, tensor_name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            # Workaround for stable-diffusion.cpp SDXL detection.
            if compat == "sd.cpp" and arch_str == "sdxl":
                if tensor_name.endswith((".proj_in.weight", ".proj_out.weight")):
                    while len(shape) > 2 and shape[-1] == 1:
                        shape = shape[:-1]

        # add to state dict
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)

        # keep track of loaded tensor types
        tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

    # print loaded tensor type counts
    logger.info("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items()))

    # mark largest tensor for vram estimation
    qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
    if len(qsd) > 0:
        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
        state_dict[max_key].is_largest_weight = True

    if return_arch:
        return (state_dict, arch_str)
    return state_dict


# for remapping llama.cpp -> original key names
T5_SD_MAP = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}

LLAMA_SD_MAP = {
    "blk.": "model.layers.",
    "attn_norm": "input_layernorm",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "ffn_norm": "post_attention_layernorm",
    "token_embd": "model.embed_tokens",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}


def sd_map_replace(raw_sd, key_map):
    """
    Return a copy of `raw_sd` whose keys had every `key_map` substring replacement applied in order.
    """
    sd = {}
    for k, v in raw_sd.items():
        for s, d in key_map.items():
            k = k.replace(s, d)
        sd[k] = v
    return sd


def llama_permute(raw_sd, n_head, n_head_kv):
    """
    Re-order q_proj / k_proj weights and biases in place and return the entries in a new dict.
    """
    # Reverse version of LlamaModel.permute in llama.cpp convert script
    def permute(x, h):
        return x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)

    sd = {}
    for k, v in raw_sd.items():
        if k.endswith(("q_proj.weight", "q_proj.bias")):
            v.data = permute(v.data, n_head)
        if k.endswith(("k_proj.weight", "k_proj.bias")):
            v.data = permute(v.data, n_head_kv)
        sd[k] = v
    return sd


def gguf_tokenizer_loader(path, temb_shape):
    """
    Rebuild a serialized sentencepiece model from the tokenizer metadata of the GGUF file at `path`.

    Returns the serialized model as a ByteTensor. Raises NotImplementedError unless the file
    declares a "t5" tokenizer and `temb_shape` is (256384, 4096).
    """
    # convert gguf tokenizer to spiece
    logger.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
    spm = model.ModelProto()

    reader = gguf.GGUFReader(path)

    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
        if temb_shape == (256384, 4096):  # probably UMT5
            # was `==` (a comparison with no effect); the intent is to set the model type
            spm.trainer_spec.model_type = 1  # Unigram (do we have a T5 w/ BPE?)
        else:
            raise NotImplementedError("Unknown model, can't set tokenizer!")
    else:
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
        # # These aren't present in the original?
        # if toktype == 5 and idx >= temb_shape[0]%1000):
        #     continue

        piece = spm.SentencePiece()
        piece.piece = token
        piece.score = score
        piece.type = toktype
        spm.pieces.append(piece)

    # unsure if any of these are correct
    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.vocab_size = len(tokens)  # split off unused?
    spm.trainer_spec.max_sentence_length = 4096
    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)

    logger.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
    del reader
    return torch.ByteTensor(list(spm.SerializeToString()))


def gguf_clip_loader(path):
    """
    Load a text-encoder GGUF file and remap its keys according to the detected architecture.
    """
    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
    if arch in {"t5", "t5encoder"}:
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
            # non-standard Comfy-Org tokenizer
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, T5_SD_MAP)
    elif arch in {"llama"}:
        # TODO: pass model_options["vocab_size"] to loader somehow
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
            # See note above for T5.
            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, LLAMA_SD_MAP)
        sd = llama_permute(sd, 32, 8)  # L3
    else:
        pass
    return sd


# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)


def chained_hasattr(obj, chained_attr):
    """
    Return True when the dotted attribute path `chained_attr` can be resolved starting from `obj`.
    """
    probe = obj
    for attr in chained_attr.split('.'):
        if hasattr(probe, attr):
            probe = getattr(probe, attr)
        else:
            return False
    return True


# A backward and forward compatible way to get `torch.compiler.disable`.
def get_torch_compiler_disable_decorator():
    """Choose the decorator factory used as `@torch_compiler_disable()`.

    Returns `torch.compiler.disable` only when it exists, the torch version
    is below 2.8, and `torch._dynamo.config.nontraceable_tensor_subclasses`
    is absent. In every other case a no-op factory is returned, so the
    decorated function is left untouched.
    """
    def dummy_decorator(*args, **kwargs):
        def noop(x):
            return x
        return noop

    from packaging import version

    if not chained_hasattr(torch, "compiler.disable"):
        logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
        return dummy_decorator  # torch too old
    elif version.parse(torch.__version__) >= version.parse("2.8"):
        logger.debug("ComfyUI-GGUF: Allowing full torch compile")
        return dummy_decorator  # torch compile works
    if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
        logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
        return dummy_decorator  # torch compile works, nightly before 2.8 release
    else:
        logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
        return torch.compiler.disable


torch_compiler_disable = get_torch_compiler_disable_decorator()


class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights

    On top of the raw tensor data it carries `tensor_type`, `tensor_shape`
    (the value reported by `.shape`) and `patches`.
    """

    def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        # A `[]` default would be one list shared by every tensor created
        # without patches; build a fresh list per instance instead.
        self.patches = [] if patches is None else patches

    def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        # The metadata keywords are consumed by __init__; only the remaining
        # arguments are forwarded to torch.Tensor.__new__.
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        """Like `Tensor.to`, but carries the GGML metadata over to the result."""
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        # Returns the same object; no copy is made.
        return self

    def detach(self, *args, **kwargs):
        # Returns the same object; no new tensor is created.
        return self

    def copy_(self, *args, **kwargs):
        """Best-effort in-place copy; a failure is logged instead of raised.

        Always returns a tensor: the result of `Tensor.copy_` on success, or
        `self` unchanged when the copy raised.
        """
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            logger.warning(f"ignoring 'copy_' on tensor: {e}")
            # In-place ops return the tensor itself; keep that contract even
            # when the copy was skipped.
            return self

    def new_empty(self, size, *args, **kwargs):
        """Like `Tensor.new_empty`, re-wrapped as a GGMLTensor reporting `size`."""
        # Intel Arc fix, ref#50
        new_tensor = super().new_empty(size, *args, **kwargs)
        return GGMLTensor(
            new_tensor,
            tensor_type=getattr(self, "tensor_type", None),
            tensor_shape=size,
            patches=getattr(self, "patches", []).copy()
        )

    @property
    def shape(self):
        # Falls back to (and caches) the storage size when no logical shape
        # was recorded on this instance.
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape


class GGMLLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """
    comfy_cast_weights = True
    # dtype handed to dequantize_tensor as its third argument
    dequant_dtype = None
    # None, the string "target", or a dtype; see get_weight
    patch_dtype = None
    # set during loading when the weight carries `is_largest_weight`
    largest_layer = False
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight=None, bias=None):
        """Return whether the weight or the bias is quantized.

        Arguments left as None fall back to this layer's own attribute.
        """
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Route loading to the GGML loader or to the stock torch loader."""
        weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias")
        # NOTE: using modified load for linear due to not initializing on creation, see GGMLOps todo
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        # Not strictly required, but fixes embedding shape mismatch. Threshold set in loader.py
        if isinstance(self, torch.nn.Embedding) and self.weight.shape[0] >= (64 * 1024):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Install weight and bias straight from `state_dict` as frozen Parameters.

        Keys are matched on what follows the first `len(prefix)` characters;
        anything other than "weight" or a non-None "bias" is appended to
        `unexpected_keys`.
        NOTE(review): this presumes every key in `state_dict` starts with
        `prefix` - confirm against the caller.
        """
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            if k[prefix_len:] == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k[prefix_len:] == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                unexpected_keys.append(k)

        # For Linear layer with missing weight
        if self.weight is None and isinstance(self, torch.nn.Linear):
            # torch.nn.functional.linear expects (out_features, in_features);
            # the placeholder used to be built transposed.
            v = torch.zeros(self.out_features, self.in_features)
            self.weight = torch.nn.Parameter(v, requires_grad=False)
            missing_keys.append(prefix + "weight")

        # for vram estimation (TODO: less fragile logic?)
        if getattr(self.weight, "is_largest_weight", False):
            self.largest_layer = True

    def _save_to_state_dict(self, *args, **kwargs):
        """Use the GGML saver for quantized layers, the stock one otherwise."""
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
        """Write meta-device stand-ins for this layer into `destination`."""
        # This is a fake state dict for vram estimation
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias

        # Take into account space required for dequantizing the largest tensor
        if self.largest_layer:
            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
            dtype = self.dequant_dtype or torch.float16
            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
            destination[prefix + "temp.weight"] = temp

        return

    def get_weight(self, tensor, dtype):
        """Dequantize `tensor` and apply any patches attached to it.

        Returns None when `tensor` is None. The result is never a GGMLTensor.
        """
        if tensor is None:
            return

        # consolidate and load patches to GPU in async
        patch_list = []
        device = tensor.device
        # `key` is deliberately read after the loop: the key of the last
        # patch group is the one handed to calculate_weight.
        for patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # prevent propagating custom tensor class
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)

        # apply patches
        if len(patch_list) > 0:
            if self.patch_dtype is None:
                weight = calculate_weight(patch_list, weight, key)
            else:
                # for testing, may degrade image quality
                patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype
                weight = calculate_weight(patch_list, weight, key, patch_dtype)
        return weight

    @torch_compiler_disable()
    def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
        """Return `(weight, bias)` dequantized and cast for the forward pass.

        `s` is the layer instance. When `input` is given, any of `dtype`,
        `bias_dtype` and `device` left as None is derived from it: `dtype`
        from `input.dtype` (float32 if `input` has no such attribute),
        `bias_dtype` from `dtype`, and `device` from `input.device`.
        """
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = device_supports_non_blocking(device)
        if s.bias is not None:
            bias = s.get_weight(s.bias.to(device), dtype)
            bias = cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)

        weight = s.get_weight(s.weight.to(device), dtype)
        weight = cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
        return weight, bias

    def forward_comfy_cast_weights(self, input, *args, **kwargs):
        """Run the GGML forward for quantized layers, the mixin's otherwise."""
        if self.is_ggml_quantized():
            out = self.forward_ggml_cast_weights(input, *args, **kwargs)
        else:
            # this is from the mixin
            out = super().forward_comfy_cast_weights(input, *args, **kwargs)  # pylint: disable=no-member

        # non-ggml forward might still propagate custom tensor class
        if isinstance(out, GGMLTensor):
            out = torch.Tensor(out)
        return out

    def forward_ggml_cast_weights(self, input):
        # Overridden by each concrete layer type in GGMLOps.
        raise NotImplementedError


class GGMLOps(manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGMLLayer, manual_cast.Linear):
        dequant_dtype = None
        patch_dtype = None

        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            # Only torch.nn.Module is initialised, so no weight or bias is
            # allocated here; `bias`, `device` and `dtype` are accepted but
            # not used.
            torch.nn.Module.__init__(self)
            # TODO: better workaround for reserved memory spike on windows
            # Issue is with `torch.empty` still reserving the full memory for the layer
            # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None

        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(GGMLLayer, manual_cast.Conv2d):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGMLLayer, manual_cast.Embedding):
        def forward_ggml_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            # NOTE(review): `self` lands in cast_bias_weight's `input` slot.
            # With `out_dtype` None the cast dtype then presumably falls back
            # to float32 (assuming the module has no `dtype` attribute).
            # Passing the index tensor here would change that - confirm
            # before "fixing" it.
            weight, _bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGMLLayer, manual_cast.LayerNorm):
        def forward_ggml_cast_weights(self, input):
            # Without a weight there is nothing to dequantize; defer to the
            # parent forward.
            if self.weight is None:
                return super().forward_comfy_cast_weights(input)
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGMLLayer, manual_cast.GroupNorm):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


def move_patch_to_device(item, device):
    """Recursively move every tensor inside `item` to `device`.

    Tensors are moved with `non_blocking=True`; tuples and lists are rebuilt
    with the same container type; anything else is returned unchanged.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        return [move_patch_to_device(x, device) for x in item]
    else:
        return item