"""
|
|
Copyright 2025 "City96" and Benjanin Berman
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import argparse
import logging
import os
import warnings
import numpy as np
import re

import gguf
import torch
from safetensors.torch import load_file, save_file
from sentencepiece import sentencepiece_model_pb2 as model
from tqdm import tqdm

from .lora import calculate_weight
from .model_management import device_supports_non_blocking
from .ops import cast_to, manual_cast

logger = logging.getLogger(__name__)

QUANTIZATION_THRESHOLD = 1024
REARRANGE_THRESHOLD = 512
MAX_TENSOR_NAME_LENGTH = 127
MAX_TENSOR_DIMS = 4
TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "qwen_image"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}

CLIP_VISION_SD_MAP = {
    "mm.": "visual.merger.mlp.",
    "v.post_ln.": "visual.merger.ln_q.",
    "v.patch_embd": "visual.patch_embed.proj",
    "v.blk.": "visual.blocks.",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "attn_out.": "attn.proj.",
    "ln1.": "norm1.",
    "ln2.": "norm2.",
}


class ModelTemplate:
    arch = "invalid" # string describing architecture
    shape_fix = False # whether to reshape tensors
    keys_detect = [] # list of lists to match in state dict
    keys_banned = [] # list of keys that should mark model as invalid for conversion
    keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason
    keys_ignore = [] # list of strings to ignore keys by when found

    def handle_nd_tensor(self, key, data):
        raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")


class ModelFlux(ModelTemplate):
    arch = "flux"
    keys_detect = [
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight", ]


class ModelSD3(ModelTemplate):
    arch = "sd3"
    keys_detect = [
        ("transformer_blocks.0.attn.add_q_proj.weight",),
        ("joint_blocks.0.x_block.attn.qkv.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight", ]


class ModelAura(ModelTemplate):
    arch = "aura"
    keys_detect = [
        ("double_layers.3.modX.1.weight",),
        ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
    ]
    keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight", ]


class ModelHiDream(ModelTemplate):
    arch = "hidream"
    keys_detect = [
        (
            "caption_projection.0.linear.weight",
            "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight"
        )
    ]
    keys_hiprec = [
        # nn.parameter, can't load from BF16 ver
        ".ff_i.gate.weight",
        "img_emb.emb_pos"
    ]


class CosmosPredict2(ModelTemplate):
    arch = "cosmos"
    keys_detect = [
        (
            "blocks.0.mlp.layer1.weight",
            "blocks.0.adaln_modulation_cross_attn.1.weight",
        )
    ]
    keys_hiprec = ["pos_embedder"]
    keys_ignore = ["_extra_state", "accum_"]


class ModelHyVid(ModelTemplate):
    arch = "hyvid"
    keys_detect = [
        (
            "double_blocks.0.img_attn_proj.weight",
            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
        )
    ]

    def handle_nd_tensor(self, key, data):
        # hacky but don't have any better ideas
        path = f"./fix_5d_tensors_{self.arch}.safetensors" # TODO: somehow get a path here??
        if os.path.isfile(path):
            raise RuntimeError(f"5D tensor fix file already exists! {path}")
        fsd = {key: torch.from_numpy(data)}
        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
        save_file(fsd, path)


class ModelWan(ModelHyVid):
    arch = "wan"
    keys_detect = [
        (
            "blocks.0.self_attn.norm_q.weight",
            "text_embedding.2.weight",
            "head.modulation",
        )
    ]
    keys_hiprec = [
        ".modulation" # nn.parameter, can't load from BF16 ver
    ]


class ModelLTXV(ModelTemplate):
    arch = "ltxv"
    keys_detect = [
        (
            "adaln_single.emb.timestep_embedder.linear_2.weight",
            "transformer_blocks.27.scale_shift_table",
            "caption_projection.linear_2.weight",
        )
    ]
    keys_hiprec = [
        "scale_shift_table" # nn.parameter, can't load from BF16 base quant
    ]


class ModelSDXL(ModelTemplate):
    arch = "sdxl"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
        ), # Non-diffusers
        ("label_emb.0.0.weight",),
    ]


class ModelSD1(ModelTemplate):
    arch = "sd1"
    shape_fix = True
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight"
        ), # Non-diffusers
    ]


# The architectures are checked in order and the first successful match terminates the search.
arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]


def is_model_arch(model, state_dict):
    # check if model is correct
    matched = False
    invalid = False
    for match_list in model.keys_detect:
        if all(key in state_dict for key in match_list):
            matched = True
            invalid = any(key in state_dict for key in model.keys_banned)
            break
    assert not invalid, "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
    return matched


def detect_arch(state_dict):
    model_arch = None
    for arch in arch_list:
        if is_model_arch(arch, state_dict):
            model_arch = arch()
            break
    assert model_arch is not None, "Unknown model architecture!"
    return model_arch


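# Illustrative sketch (not part of the conversion flow): detection only looks at
# key names, so a state dict of empty tensors is enough to exercise detect_arch()
# above. The key below is one of ModelFlux's keys_detect entries.
#
#   sd = {"double_blocks.0.img_attn.proj.weight": torch.empty(1)}
#   assert detect_arch(sd).arch == "flux"

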
def parse_args():
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
    parser.add_argument("--src", required=True, help="Source model ckpt file.")
    parser.add_argument("--dst", help="Output unet gguf file.")
    args = parser.parse_args()

    if not os.path.isfile(args.src):
        parser.error("No input provided!")

    return args


def strip_prefix(state_dict):
    # prefix for mixed state dict
    prefix = None
    for pfx in ["model.diffusion_model.", "model."]:
        if any([x.startswith(pfx) for x in state_dict.keys()]):
            prefix = pfx
            break

    # prefix for uniform state dict
    if prefix is None:
        for pfx in ["net."]:
            if all([x.startswith(pfx) for x in state_dict.keys()]):
                prefix = pfx
                break

    # strip prefix if found
    if prefix is not None:
        logger.info(f"State dict prefix found: '{prefix}'")
        sd = {}
        for k, v in state_dict.items():
            if prefix not in k:
                continue
            k = k.replace(prefix, "")
            sd[k] = v
    else:
        logger.debug("State dict has no prefix")
        sd = state_dict

    return sd


def load_state_dict(path):
    if any(path.endswith(x) for x in [".ckpt", ".pt", ".bin", ".pth"]):
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        for subkey in ["model", "module"]:
            if subkey in state_dict:
                state_dict = state_dict[subkey]
                break
        if len(state_dict) < 20:
            raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
    else:
        state_dict = load_file(path)

    return strip_prefix(state_dict)


def handle_tensors(writer, state_dict, model_arch):
    name_lengths = tuple(sorted(
        ((key, len(key)) for key in state_dict.keys()),
        key=lambda item: item[1],
        reverse=True,
    ))
    if not name_lengths:
        return
    max_name_len = name_lengths[0][1]
    if max_name_len > MAX_TENSOR_NAME_LENGTH:
        bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
        raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")
    for key, data in tqdm(state_dict.items()):
        old_dtype = data.dtype

        if any(x in key for x in model_arch.keys_ignore):
            tqdm.write(f"Filtering ignored key: '{key}'")
            continue

        if data.dtype == torch.bfloat16:
            data = data.to(torch.float32).numpy()
        # this is so we don't break torch 2.0.X
        elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
            data = data.to(torch.float16).numpy()
        else:
            data = data.numpy()

        n_dims = len(data.shape)
        data_shape = data.shape
        if old_dtype == torch.bfloat16:
            data_qtype = gguf.GGMLQuantizationType.BF16
        # elif old_dtype == torch.float32:
        #     data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16

        # The max no. of dimensions that can be handled by the quantization code is 4
        if len(data.shape) > MAX_TENSOR_DIMS:
            model_arch.handle_nd_tensor(key, data)
            continue # needs to be added back later

        # get number of parameters (AKA elements) in this tensor
        n_params = 1
        for dim_size in data_shape:
            n_params *= dim_size

        if old_dtype in (torch.float32, torch.bfloat16):
            if n_dims == 1:
                # one-dimensional tensors should be kept in F32
                # also speeds up inference due to not dequantizing
                data_qtype = gguf.GGMLQuantizationType.F32

            elif n_params <= QUANTIZATION_THRESHOLD:
                # very small tensors
                data_qtype = gguf.GGMLQuantizationType.F32

            elif any(x in key for x in model_arch.keys_hiprec):
                # tensors that require max precision
                data_qtype = gguf.GGMLQuantizationType.F32

        if (model_arch.shape_fix # NEVER reshape for models such as flux
                and n_dims > 1 # Skip one-dimensional tensors
                and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement
                and (n_params / 256).is_integer() # Rearranging only makes sense if total elements is divisible by 256
                and not (data.shape[-1] / 256).is_integer() # Only need to rearrange if the last dimension is not divisible by 256
        ):
            orig_shape = data.shape
            data = data.reshape(n_params // 256, 256)
            writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))

        try:
            data = gguf.quants.quantize(data, data_qtype)
        except (AttributeError, gguf.QuantError) as e:
            tqdm.write(f"falling back to F16: {e}")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        new_name = key # do we need to rename?

        shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
        tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

        writer.add_tensor(new_name, data, raw_dtype=data_qtype)


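# Note on the shape_fix branch in handle_tensors() above: the k-quant kernels
# work on 256-element blocks (see QK_K below), which SD1/SDXL conv weight rows
# don't line up with. Flattening such tensors to (n_params // 256, 256) and
# recording the original shape under "comfy.gguf.orig_shape.<key>" lets
# get_orig_shape() further down restore the true shape at load time.

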
def convert_file(path, dst_path=None, interact=True, overwrite=False):
    # load & run model detection logic
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    logger.info(f"* Architecture detected from input: {model_arch.arch}")

    # detect & set dtype for output file
    dtypes = [x.dtype for x in state_dict.values()]
    dtypes = {x: dtypes.count(x) for x in set(dtypes)}
    main_dtype = max(dtypes, key=dtypes.get)

    if main_dtype == torch.bfloat16:
        ftype_name = "BF16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
    # elif main_dtype == torch.float32:
    #     ftype_name = "F32"
    #     ftype_gguf = None
    else:
        ftype_name = "F16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16

    if dst_path is None:
        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
    elif "{ftype}" in dst_path: # lcpp logic
        dst_path = dst_path.replace("{ftype}", ftype_name)

    if os.path.isfile(dst_path) and not overwrite:
        if interact:
            input("Output exists, press enter to continue or ctrl+c to abort!")
        else:
            raise OSError("Output exists and overwriting is disabled!")

    # handle actual file
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    if ftype_gguf is not None:
        writer.add_file_type(ftype_gguf)

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=dst_path)
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()

    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
    if os.path.isfile(fix):
        logger.warning(f"\n### Warning! Fix file found at '{fix}'")
        logger.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")

    return dst_path, model_arch


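# Rough usage sketch (this module exposes no __main__ entry point; wiring it up
# as a script is left to the caller):
#
#   args = parse_args()
#   dst_path, model_arch = convert_file(args.src, args.dst)
#
# The returned path is the F16/BF16 GGUF file, which can then be handed to an
# external llama.cpp-style quantization step; model_arch is the detected
# ModelTemplate instance.

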
def is_torch_compatible(tensor):
    return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES


def is_quantized(tensor):
    return not is_torch_compatible(tensor)


def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)
    elif qtype in dequantize_functions:
        dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype
        return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype)
    else:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)


def dequantize(data, qtype, oshape, dtype=None):
    """
    Dequantize tensor back to usable shape/dtype
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    dequantize_blocks = dequantize_functions[qtype]

    rows = data.reshape(
        (-1, data.shape[-1])
    ).view(torch.uint8)

    n_blocks = rows.numel() // type_size
    blocks = rows.reshape((n_blocks, type_size))
    blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
    return blocks.reshape(oshape)


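# For reference: gguf.GGML_QUANT_SIZES maps each quant type to a
# (block_size, type_size) pair, i.e. weights per block and bytes per block, so
# the reshape above turns the raw byte payload into (n_blocks, type_size) rows
# before handing them to the matching dequantize_blocks_* kernel.

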
def to_uint32(x):
    # no uint32 :(
    x = x.view(torch.uint8).to(torch.int32)
    return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)


def split_block_dims(blocks, *args):
    n_max = blocks.shape[1]
    dims = list(args) + [n_max - sum(args)]
    return torch.split(blocks, dims, dim=1)


# Full weights #
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)


# Legacy Quants #
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    d, x = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)
    x = x.view(torch.int8)
    return (d * x)


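# Layout reminder for the legacy quants: a Q8_0 block packs 32 weights into
# 34 bytes, a float16 scale `d` followed by 32 signed int8 values, which is why
# split_block_dims(blocks, 2) slices off a 2-byte scale column and the
# dequantized result is simply d * x.

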
def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))

    qs = (ql | (qh << 4))
    return (d * qs) + m


def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    qh = to_uint32(qh)

    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)

    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)

    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return (d * qs)


def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, m, qs = split_block_dims(blocks, 2, 2)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qs = (qs & 0x0F).reshape(n_blocks, -1)

    return (d * qs) + m


def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, qs = split_block_dims(blocks, 2)
    d = d.view(torch.float16).to(dtype)

    qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return (d * qs)


# K Quants #
QK_K = 256
K_SCALE_SIZE = 12


def get_scale_min(scales):
    n_blocks = scales.shape[0]
    scales = scales.view(torch.uint8)
    scales = scales.reshape((n_blocks, 3, 4))

    d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)

    sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
    min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)

    return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))


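# K-quants group weights into 256-element super-blocks (QK_K). For Q4_K/Q5_K the
# 12-byte scale field (K_SCALE_SIZE) unpacked by get_scale_min() holds eight
# 6-bit sub-block scales plus eight 6-bit sub-block minimums, which is what the
# bit shuffling above reconstructs.

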
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    ql, qh, scales, d = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)

    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))

    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))

    return (d * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)

    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4))

    return (d * q - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    sc, m = get_scale_min(scales)

    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))

    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))

    return (d * qs - dm).reshape((n_blocks, QK_K))


def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)

    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = (scales.to(torch.int8) - 32)

    dl = (d * scales).reshape((n_blocks, 16, 1))

    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))

    return (dl * q).reshape((n_blocks, QK_K))


def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
    n_blocks = blocks.shape[0]

    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)

    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))

    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))

    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml

    return qs.reshape((n_blocks, -1))


dequantize_functions = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}


# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)


def get_orig_shape(reader, tensor_name):
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # Has original shape metadata, so we try to decode it.
    if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32:
        raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")
    return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))


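# Counterpart to the writer.add_array() call in handle_tensors(): tensors that
# were flattened to (n_params // 256, 256) during conversion carry their true
# shape in this metadata key, and gguf_sd_loader() below uses it to restore the
# logical shape before wrapping the data in a GGMLTensor.

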
def get_field(reader, field_name, field_type):
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type == str:
        # extra check here as this is used for checking arch string
        if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
        return str(field.parts[field.data[-1]], encoding="utf-8")
    elif field_type in [int, float, bool]:
        return field_type(field.parts[field.data[-1]])
    else:
        raise TypeError(f"Unknown field type {field_type}")


def get_list_field(reader, field_name, field_type):
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type == str:
        return tuple(str(field.parts[part_idx], encoding="utf-8") for part_idx in field.data)
    elif field_type in [int, float, bool]:
        return tuple(field_type(field.parts[part_idx][0]) for part_idx in field.data)
    else:
        raise TypeError(f"Unknown field type {field_type}")


def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False) -> dict:
|
|
"""
|
|
Read state dict as fake tensors
|
|
"""
|
|
reader = gguf.GGUFReader(path)
|
|
|
|
# filter and strip prefix
|
|
has_prefix = False
|
|
if handle_prefix is not None:
|
|
prefix_len = len(handle_prefix)
|
|
tensor_names = set(tensor.name for tensor in reader.tensors)
|
|
has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
|
|
else:
|
|
prefix_len = 0
|
|
|
|
tensors = []
|
|
for tensor in reader.tensors:
|
|
sd_key = tensor_name = tensor.name
|
|
if has_prefix:
|
|
if not tensor_name.startswith(handle_prefix):
|
|
continue
|
|
sd_key = tensor_name[prefix_len:]
|
|
tensors.append((sd_key, tensor))
|
|
|
|
# detect and verify architecture
|
|
compat = None
|
|
arch_str = get_field(reader, "general.architecture", str)
|
|
if arch_str in [None, "pig"]:
|
|
if is_text_model:
|
|
raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})")
|
|
compat = "sd.cpp" if arch_str is None else arch_str
|
|
# import here to avoid changes to convert.py breaking regular models
|
|
try:
|
|
arch_str = detect_arch(set(val[0] for val in tensors)).arch
|
|
except Exception as e:
|
|
raise ValueError(f"This model is not currently supported - ({e})")
|
|
elif arch_str not in TXT_ARCH_LIST and is_text_model:
|
|
logger.warning(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
|
|
elif arch_str not in IMG_ARCH_LIST and not is_text_model:
|
|
logger.warning(f"Unexpected architecture type in GGUF file: {arch_str!r}")
|
|
|
|
if compat:
|
|
logger.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")
|
|
|
|
# main loading loop
|
|
state_dict = {}
|
|
qtype_dict = {}
|
|
for sd_key, tensor in tensors:
|
|
tensor_name = tensor.name
|
|
# torch_tensor = torch.from_numpy(tensor.data) # mmap
|
|
|
|
# NOTE: line above replaced with this block to avoid persistent numpy warning about mmap
|
|
with warnings.catch_warnings():
|
|
warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
|
|
torch_tensor = torch.from_numpy(tensor.data) # mmap
|
|
|
|
shape = get_orig_shape(reader, tensor_name)
|
|
if shape is None:
|
|
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
|
|
# Workaround for stable-diffusion.cpp SDXL detection.
|
|
if compat == "sd.cpp" and arch_str == "sdxl":
|
|
if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]):
|
|
while len(shape) > 2 and shape[-1] == 1:
|
|
shape = shape[:-1]
|
|
|
|
# add to state dict
|
|
if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
|
|
torch_tensor = torch_tensor.view(*shape)
|
|
state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
|
|
|
|
# keep track of loaded tensor types
|
|
tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
|
|
qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
|
|
|
|
# print loaded tensor type counts
|
|
logger.debug("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items()))
|
|
|
|
# mark largest tensor for vram estimation
|
|
qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
|
|
if len(qsd) > 0:
|
|
max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
|
|
state_dict[max_key].is_largest_weight = True
|
|
|
|
if return_arch:
|
|
return (state_dict, arch_str)
|
|
return state_dict
|
|
|
|
|
|
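# Minimal usage sketch (the file name is a placeholder):
#
#   sd = gguf_sd_loader("flux1-dev-Q4_K_S.gguf")
#   sd, arch_str = gguf_sd_loader(path, return_arch=True)
#
# Values are GGMLTensor instances that keep the raw quantized payload; nothing
# is dequantized until a GGMLLayer actually needs the weight at forward time.

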
# for remapping llama.cpp -> original key names
T5_SD_MAP = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}

LLAMA_SD_MAP = {
    "blk.": "model.layers.",
    "attn_norm": "input_layernorm",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "ffn_norm": "post_attention_layernorm",
    "token_embd": "model.embed_tokens",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}


def sd_map_replace(raw_sd, key_map):
    sd = {}
    for k, v in raw_sd.items():
        for s, d in key_map.items():
            k = k.replace(s, d)
        sd[k] = v
    return sd


def llama_permute(raw_sd, n_head, n_head_kv):
    # Reverse version of LlamaModel.permute in llama.cpp convert script
    sd = {}
    permute = lambda x, h: x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)
    for k, v in raw_sd.items():
        if k.endswith(("q_proj.weight", "q_proj.bias")):
            v.data = permute(v.data, n_head)
        if k.endswith(("k_proj.weight", "k_proj.bias")):
            v.data = permute(v.data, n_head_kv)
        sd[k] = v
    return sd


def gguf_tokenizer_loader(path, temb_shape):
    # convert gguf tokenizer to spiece
    logger.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
    spm = model.ModelProto()

    reader = gguf.GGUFReader(path)

    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
        if temb_shape == (256384, 4096): # probably UMT5
            spm.trainer_spec.model_type = 1 # Unigram (do we have a T5 w/ BPE?)
        else:
            raise NotImplementedError("Unknown model, can't set tokenizer!")
    else:
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
        # # These aren't present in the original?
        # if toktype == 5 and idx >= temb_shape[0]%1000):
        #     continue

        piece = spm.SentencePiece()
        piece.piece = token
        piece.score = score
        piece.type = toktype
        spm.pieces.append(piece)

    # unsure if any of these are correct
    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.vocab_size = len(tokens) # split off unused?
    spm.trainer_spec.max_sentence_length = 4096
    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)

    logger.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
    del reader
    return torch.ByteTensor(list(spm.SerializeToString()))


def strip_quant_suffix(name):
    pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
    match = re.search(pattern, name, re.IGNORECASE)
    if match:
        name = name[:match.start()]
    return name


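# Examples of the suffix stripping above (file names are illustrative):
#
#   strip_quant_suffix("flux1-dev-q8_0")       -> "flux1-dev"
#   strip_quant_suffix("qwen2.5-vl-7b-q4_k_m") -> "qwen2.5-vl-7b"

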
def gguf_mmproj_loader(path):
    # Reverse version of Qwen2VLVisionModel.modify_tensors
    logger.info("Attempting to find mmproj file for text encoder...")

    # get name to match w/o quant suffix
    tenc_fname = os.path.basename(path)
    tenc = os.path.splitext(tenc_fname)[0].lower()
    tenc = strip_quant_suffix(tenc)

    # try and find matching mmproj
    target = []
    root = os.path.dirname(path)
    for fname in os.listdir(root):
        name, ext = os.path.splitext(fname)
        if ext.lower() != ".gguf":
            continue
        if "mmproj" not in name.lower():
            continue
        if tenc in name.lower():
            target.append(fname)

    if len(target) == 0:
        logger.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
        return {}
    if len(target) > 1:
        logger.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")

    logger.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
    target = os.path.join(root, target[0])
    vsd = gguf_sd_loader(target, is_text_model=True)

    # concat 4D to 5D
    if "v.patch_embd.weight.1" in vsd:
        w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
        w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
        vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)

    # run main replacement
    vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)

    # handle split Q/K/V
    if "visual.blocks.0.attn_q.weight" in vsd:
        attns = {}
        # filter out attentions + group
        for k, v in vsd.items():
            if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
                k_attn, k_name = k.rsplit(".attn_", 1)
                k_attn += ".attn.qkv." + k_name.split(".")[-1]
                if k_attn not in attns:
                    attns[k_attn] = {}
                attns[k_attn][k_name] = dequantize_tensor(
                    v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
                )

        # recombine
        for k, v in attns.items():
            suffix = k.split(".")[-1]
            vsd[k] = torch.cat([
                v[f"q.{suffix}"],
                v[f"k.{suffix}"],
                v[f"v.{suffix}"],
            ], dim=0)
        del attns

    return vsd


def gguf_tekken_tokenizer_loader(path, temb_shape):
    # convert ggml (hf) tokenizer metadata to tekken/comfy data
    logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
    import json
    import base64
    from transformers.convert_slow_tokenizer import bytes_to_unicode

    reader = gguf.GGUFReader(path)

    model_str = get_field(reader, "tokenizer.ggml.model", str)
    if model_str == "gpt2":
        if temb_shape == (131072, 5120): # probably Mistral
            data = {
                "config": {"num_vocab_tokens": 150000, "default_vocab_size": 131072},
                "vocab": [],
                "special_tokens": [],
            }
        else:
            raise NotImplementedError("Unknown model, can't set tokenizer!")
    else:
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    decoder = {v: k for k, v in bytes_to_unicode().items()}
    for idx, (token, toktype) in enumerate(zip(tokens, toktypes)):
        if toktype == 3:
            data["special_tokens"].append(
                {'rank': idx, 'token_str': token, 'is_control': True}
            )
        else:
            tok = bytes([decoder[char] for char in token])
            data["vocab"].append({
                "rank": len(data["vocab"]),
                "token_bytes": base64.b64encode(tok).decode("ascii"),
                "token_str": tok.decode("utf-8", errors="replace") # ?
            })

    logger.info(f"Created tekken tokenizer with vocab size of {len(data['vocab'])} (+{len(data['special_tokens'])})")
    del reader
    return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))


def gguf_clip_loader(path):
    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
    if arch in {"t5", "t5encoder"}:
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
            # non-standard Comfy-Org tokenizer
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, T5_SD_MAP)
    elif arch in {"llama", "qwen2vl"}:
        # TODO: pass model_options["vocab_size"] to loader somehow
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
            if arch == "llama" and sd[temb_key].shape == (131072, 5120):
                # non-standard Comfy-Org tokenizer
                sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
            # See note above for T5.
            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, LLAMA_SD_MAP)
        if arch == "llama":
            sd = llama_permute(sd, 32, 8) # L3 / Mistral
        if arch == "qwen2vl":
            vsd = gguf_mmproj_loader(path)
            sd.update(vsd)
    else:
        pass
    return sd


# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)


def chained_hasattr(obj, chained_attr):
    probe = obj
    for attr in chained_attr.split('.'):
        if hasattr(probe, attr):
            probe = getattr(probe, attr)
        else:
            return False
    return True


# A backward and forward compatible way to get `torch.compiler.disable`.
def get_torch_compiler_disable_decorator():
    def dummy_decorator(*args, **kwargs):
        def noop(x):
            return x

        return noop

    from packaging import version

    if not chained_hasattr(torch, "compiler.disable"):
        logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
        return dummy_decorator # torch too old
    elif version.parse(torch.__version__) >= version.parse("2.8"):
        logger.debug("ComfyUI-GGUF: Allowing full torch compile")
        return dummy_decorator # torch compile works
    if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
        logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
        return dummy_decorator # torch compile works, nightly before 2.8 release
    else:
        logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
        return torch.compiler.disable


torch_compiler_disable = get_torch_compiler_disable_decorator()


class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """

    def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        self.patches = patches

    def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            logger.warning(f"ignoring 'copy_' on tensor: {e}")

    def new_empty(self, size, *args, **kwargs):
        # Intel Arc fix, ref#50
        new_tensor = super().new_empty(size, *args, **kwargs)
        return GGMLTensor(
            new_tensor,
            tensor_type=getattr(self, "tensor_type", None),
            tensor_shape=size,
            patches=getattr(self, "patches", []).copy()
        )

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape


class GGMLLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """
    comfy_cast_weights = True
    dequant_dtype = None
    patch_dtype = None
    largest_layer = False
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight=None, bias=None):
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias")
        # NOTE: using modified load for linear due to not initializing on creation, see GGMLOps todo
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        # Not strictly required, but fixes embedding shape mismatch. Threshold set in loader.py
        if isinstance(self, torch.nn.Embedding) and self.weight.shape[0] >= (64 * 1024):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            if k[prefix_len:] == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k[prefix_len:] == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                unexpected_keys.append(k)

        # For Linear layer with missing weight
        if self.weight is None and isinstance(self, torch.nn.Linear):
            v = torch.zeros(self.in_features, self.out_features)
            self.weight = torch.nn.Parameter(v, requires_grad=False)
            missing_keys.append(prefix + "weight")

        # for vram estimation (TODO: less fragile logic?)
        if getattr(self.weight, "is_largest_weight", False):
            self.largest_layer = True

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
        # This is a fake state dict for vram estimation
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias

        # Take into account space required for dequantizing the largest tensor
        if self.largest_layer:
            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
            dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16
            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
            destination[prefix + "temp.weight"] = temp

        return

    def get_weight(self, tensor, dtype):
        if tensor is None:
            return

        # consolidate and load patches to GPU in async
        patch_list = []
        device = tensor.device
        for patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # prevent propagating custom tensor class
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)

        # apply patches
        if len(patch_list) > 0:
            if self.patch_dtype is None:
                weight = calculate_weight(patch_list, weight, key)
            else:
                # for testing, may degrade image quality
                patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype
                weight = calculate_weight(patch_list, weight, key, patch_dtype)
        return weight

    @torch_compiler_disable()
    def cast_bias_weight(self, input=None, dtype=None, device=None, bias_dtype=None):
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = device_supports_non_blocking(device)
        if self.bias is not None:
            bias = self.get_weight(self.bias.to(device), dtype)
            bias = cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)

        weight = self.get_weight(self.weight.to(device), dtype)
        weight = cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
        return weight, bias

    def forward_comfy_cast_weights(self, input, *args, **kwargs):
        if self.is_ggml_quantized():
            out = self.forward_ggml_cast_weights(input, *args, **kwargs)
        else:
            # this is from the mixin
            out = super().forward_comfy_cast_weights(input, *args, **kwargs) # pylint: disable=no-member

        # non-ggml forward might still propagate custom tensor class
        if isinstance(out, GGMLTensor):
            out = torch.Tensor(out)
        return out

    def forward_ggml_cast_weights(self, input):
        raise NotImplementedError


class GGMLOps(manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGMLLayer, manual_cast.Linear):
        dequant_dtype = None
        patch_dtype = None

        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            torch.nn.Module.__init__(self)
            # TODO: better workaround for reserved memory spike on windows
            # Issue is with `torch.empty` still reserving the full memory for the layer
            # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None

        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(GGMLLayer, manual_cast.Conv2d):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGMLLayer, manual_cast.Embedding):
        def forward_ggml_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, _bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGMLLayer, manual_cast.LayerNorm):
        def forward_ggml_cast_weights(self, input):
            if self.weight is None:
                return super().forward_comfy_cast_weights(input)
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGMLLayer, manual_cast.GroupNorm):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)


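# Usage sketch (the wiring below is hypothetical and depends on the caller):
# GGMLOps is meant to stand in for the usual ComfyUI operations class, so every
# Linear/Conv2d/Embedding the model builds keeps its GGUF-quantized weight
# resident and only dequantizes per forward pass, e.g.
#
#   ops = GGMLOps()
#   ops.Linear.dequant_dtype = torch.float16  # optional: dequantize in fp16
#   unet = SomeDiffusionModel(..., operations=ops)  # model class name is illustrative

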
def move_patch_to_device(item, device):
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        return [move_patch_to_device(x, device) for x in item]
    else:
        return item