ComfyUI/comfy/gguf.py
doctorpangloss 7fb748fcef wip merge
2025-12-09 13:22:27 -08:00

1344 lines
50 KiB
Python

"""
Copyright 2025 "City96" and Benjanin Berman
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import logging
import os
import warnings
import numpy as np
import re
import gguf
import torch
from safetensors.torch import load_file, save_file
from sentencepiece import sentencepiece_model_pb2 as model
from tqdm import tqdm
from .lora import calculate_weight
from .model_management import device_supports_non_blocking
from .ops import cast_to, manual_cast
# Module-level logger used by the conversion and loading helpers below.
logger = logging.getLogger(__name__)
# handle_tensors keeps float32/bfloat16 tensors with at most this many elements in F32.
QUANTIZATION_THRESHOLD = 1024
# handle_tensors only considers the 256-column rearrangement for tensors with at least this many elements.
REARRANGE_THRESHOLD = 512
# Longest tensor name accepted by handle_tensors; longer names raise ValueError.
MAX_TENSOR_NAME_LENGTH = 127
# Tensors with more dimensions than this are handed to the architecture's handle_nd_tensor.
MAX_TENSOR_DIMS = 4
# Tensor types that is_torch_compatible / dequantize_tensor treat as needing no dequantization.
TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
# Architecture strings gguf_sd_loader expects for image models; anything else only logs a warning.
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "qwen_image"}
# Architecture strings gguf_sd_loader expects when is_text_model is set.
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl"}
# Substring replacements applied (via sd_map_replace) by gguf_mmproj_loader to rename
# mmproj tensor keys to their "visual.*" counterparts.
CLIP_VISION_SD_MAP = {
    "mm.": "visual.merger.mlp.",
    "v.post_ln.": "visual.merger.ln_q.",
    "v.patch_embd": "visual.patch_embed.proj",
    "v.blk.": "visual.blocks.",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "attn_out.": "attn.proj.",
    "ln1.": "norm1.",
    "ln2.": "norm2.",
}
class ModelTemplate:
    """
    Base description of a convertible model architecture.

    Subclasses override the class attributes below; detection reads
    ``keys_detect`` / ``keys_banned`` and the tensor conversion reads
    ``shape_fix`` / ``keys_hiprec`` / ``keys_ignore``.
    """
    # architecture name written to the output file
    arch = "invalid"
    # when True, eligible tensors are reshaped to 256 columns before quantization
    shape_fix = False
    # alternative groups of keys; a group matches when all of its keys are present
    keys_detect = []
    # keys whose presence makes a detected model invalid for conversion
    keys_banned = []
    # substrings of keys that must stay in F32
    keys_hiprec = []
    # substrings of keys that are dropped during conversion
    keys_ignore = []

    def handle_nd_tensor(self, key, data):
        """Reject tensors with too many dimensions; subclasses may override."""
        message = f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})"
        raise NotImplementedError(message)
class ModelFlux(ModelTemplate):
    """Detection rules for the "flux" architecture."""
    arch = "flux"
    # Either key identifies the model. The first key is also listed in
    # keys_banned, so a state dict containing it is detected and then rejected.
    keys_detect = [
        ("transformer_blocks.0.attn.norm_added_k.weight",),
        ("double_blocks.0.img_attn.proj.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.norm_added_k.weight", ]
class ModelSD3(ModelTemplate):
    """Detection rules for the "sd3" architecture."""
    arch = "sd3"
    # Either key identifies the model. The first key is also listed in
    # keys_banned, so a state dict containing it is detected and then rejected.
    keys_detect = [
        ("transformer_blocks.0.attn.add_q_proj.weight",),
        ("joint_blocks.0.x_block.attn.qkv.weight",),
    ]
    keys_banned = ["transformer_blocks.0.attn.add_q_proj.weight", ]
class ModelAura(ModelTemplate):
    """Detection rules for the "aura" architecture."""
    arch = "aura"
    # Either key identifies the model. The second key is also listed in
    # keys_banned, so a state dict containing it is detected and then rejected.
    keys_detect = [
        ("double_layers.3.modX.1.weight",),
        ("joint_transformer_blocks.3.ff_context.out_projection.weight",),
    ]
    keys_banned = ["joint_transformer_blocks.3.ff_context.out_projection.weight", ]
class ModelHiDream(ModelTemplate):
    """Detection rules for the "hidream" architecture."""
    arch = "hidream"
    # Both keys must be present for the model to match.
    keys_detect = [
        (
            "caption_projection.0.linear.weight",
            "double_stream_blocks.0.block.ff_i.shared_experts.w3.weight"
        )
    ]
    # Keys containing any of these substrings are kept in F32 by handle_tensors.
    keys_hiprec = [
        # nn.parameter, can't load from BF16 ver
        ".ff_i.gate.weight",
        "img_emb.emb_pos"
    ]
class CosmosPredict2(ModelTemplate):
    """Detection rules for the "cosmos" architecture."""
    arch = "cosmos"
    # Both keys must be present for the model to match.
    keys_detect = [
        (
            "blocks.0.mlp.layer1.weight",
            "blocks.0.adaln_modulation_cross_attn.1.weight",
        )
    ]
    # Keys containing this substring are kept in F32 by handle_tensors.
    keys_hiprec = ["pos_embedder"]
    # Keys containing any of these substrings are skipped by handle_tensors.
    keys_ignore = ["_extra_state", "accum_"]
class ModelHyVid(ModelTemplate):
    """Detection rules for the "hyvid" architecture, plus a side-channel for 5D tensors."""
    arch = "hyvid"
    # Both keys must be present for the model to match.
    keys_detect = [
        (
            "double_blocks.0.img_attn_proj.weight",
            "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight",
        )
    ]

    def handle_nd_tensor(self, key, data):
        """
        Save an over-dimensioned tensor to a side file in the current directory.

        Raises RuntimeError when the side file already exists, so an earlier
        file is never overwritten.
        """
        # hacky but don't have any better ideas
        # TODO: somehow get a path here??
        fix_path = f"./fix_5d_tensors_{self.arch}.safetensors"
        if os.path.isfile(fix_path):
            raise RuntimeError(f"5D tensor fix file already exists! {fix_path}")

        side_state = {key: torch.from_numpy(data)}
        tqdm.write(f"5D key found in state dict! Manual fix required! - {key} {data.shape}")
        save_file(side_state, fix_path)
class ModelWan(ModelHyVid):
    """Detection rules for the "wan" architecture; inherits handle_nd_tensor from ModelHyVid."""
    arch = "wan"
    # All three keys must be present for the model to match.
    keys_detect = [
        (
            "blocks.0.self_attn.norm_q.weight",
            "text_embedding.2.weight",
            "head.modulation",
        )
    ]
    # Keys containing this substring are kept in F32 by handle_tensors.
    keys_hiprec = [
        ".modulation" # nn.parameter, can't load from BF16 ver
    ]
class ModelLTXV(ModelTemplate):
    """Detection rules for the "ltxv" architecture."""
    arch = "ltxv"
    # All three keys must be present for the model to match.
    keys_detect = [
        (
            "adaln_single.emb.timestep_embedder.linear_2.weight",
            "transformer_blocks.27.scale_shift_table",
            "caption_projection.linear_2.weight",
        )
    ]
    # Keys containing this substring are kept in F32 by handle_tensors.
    keys_hiprec = [
        "scale_shift_table" # nn.parameter, can't load from BF16 base quant
    ]
class ModelSDXL(ModelTemplate):
    """Detection rules for the "sdxl" architecture."""
    arch = "sdxl"
    # Enables the 256-column rearrangement of eligible tensors in handle_tensors.
    shape_fix = True
    # Any one of these key groups (all keys of the group present) identifies the model.
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight", "add_embedding.linear_1.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight",
            "output_blocks.2.2.conv.weight", "output_blocks.5.2.conv.weight",
        ), # Non-diffusers
        ("label_emb.0.0.weight",),
    ]
class ModelSD1(ModelTemplate):
    """Detection rules for the "sd1" architecture."""
    arch = "sd1"
    # Enables the 256-column rearrangement of eligible tensors in handle_tensors.
    shape_fix = True
    # Any one of these key groups (all keys of the group present) identifies the model.
    keys_detect = [
        ("down_blocks.0.downsamplers.0.conv.weight",),
        (
            "input_blocks.3.0.op.weight", "input_blocks.6.0.op.weight", "input_blocks.9.0.op.weight",
            "output_blocks.2.1.conv.weight", "output_blocks.5.2.conv.weight", "output_blocks.8.2.conv.weight"
        ), # Non-diffusers
    ]
# The architectures are checked in order and the first successful match terminates the search.
# Order matters: e.g. ModelSD1's first detection key is a subset of ModelSDXL's first
# key group, so ModelSDXL has to be tried before ModelSD1.
arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]
def is_model_arch(model, state_dict):
    """
    Check whether ``state_dict`` matches the architecture described by ``model``.

    ``model`` must expose ``keys_detect`` (groups of keys, one complete group
    is enough for a match) and ``keys_banned``. ``state_dict`` only needs to
    support ``in`` checks, so a set of key names works as well as a dict.

    Returns True on a match, False otherwise.

    Raises AssertionError when a group matches but a banned key is present.
    The error is raised explicitly instead of via ``assert`` so the check is
    still enforced when Python runs with ``-O``.
    """
    for match_list in model.keys_detect:
        if all(key in state_dict for key in match_list):
            # only the first matching group is considered, as before
            if any(key in state_dict for key in model.keys_banned):
                raise AssertionError(
                    "Model architecture not allowed for conversion! (i.e. reference VS diffusers format)"
                )
            return True
    return False
def detect_arch(state_dict):
    """
    Return an instance of the first architecture in ``arch_list`` matching ``state_dict``.

    Raises AssertionError when nothing matches (or when ``is_model_arch``
    rejects a banned layout). The error is raised explicitly instead of via
    ``assert`` so that running with ``-O`` cannot make this return None.
    """
    for arch in arch_list:
        if is_model_arch(arch, state_dict):
            return arch()
    raise AssertionError("Unknown model architecture!")
def parse_args():
    """
    Parse the command-line arguments of the converter.

    ``--src`` is required and must point to an existing file; otherwise the
    parser exits with a usage error. ``--dst`` is optional.
    """
    parser = argparse.ArgumentParser(description="Generate F16 GGUF files from single UNET")
    parser.add_argument("--src", required=True, help="Source model ckpt file.")
    parser.add_argument("--dst", help="Output unet gguf file.")
    args = parser.parse_args()

    if not os.path.isfile(args.src):
        # --src is always provided here (required=True); the actual problem is a missing file
        parser.error(f"Input file not found: {args.src!r}")
    return args
def strip_prefix(state_dict):
    """
    Remove a known key prefix from ``state_dict`` and drop keys without it.

    "model.diffusion_model." or "model." is selected when at least one key
    starts with it (mixed state dict); "net." is selected only when every key
    starts with it (uniform state dict). When no prefix is found the input
    dict is returned unchanged.

    Only a leading prefix is considered: keys are kept when they start with
    the prefix, and exactly that leading prefix is cut off. (Substring
    matching would also keep keys such as "first_stage_model.x" for the
    "model." prefix and would delete interior occurrences of the prefix.)
    """
    # prefix for mixed state dict
    prefix = None
    for pfx in ("model.diffusion_model.", "model."):
        if any(key.startswith(pfx) for key in state_dict):
            prefix = pfx
            break

    # prefix for uniform state dict
    if prefix is None:
        for pfx in ("net.",):
            if all(key.startswith(pfx) for key in state_dict):
                prefix = pfx
                break

    if prefix is None:
        logger.debug("State dict has no prefix")
        return state_dict

    # strip prefix if found
    logger.info(f"State dict prefix found: '{prefix}'")
    prefix_len = len(prefix)
    return {
        key[prefix_len:]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
def load_state_dict(path):
    """
    Load a checkpoint from ``path`` and return its prefix-stripped state dict.

    Files ending in .ckpt/.pt/.bin/.pth go through ``torch.load`` (weights
    only, on CPU), unwrapping a top-level "model" or "module" entry when
    present; everything else is read with safetensors' ``load_file``.

    Raises RuntimeError when a torch-loaded dict has fewer than 20 entries.
    """
    torch_suffixes = (".ckpt", ".pt", ".bin", ".pth")
    if not path.endswith(torch_suffixes):
        return strip_prefix(load_file(path))

    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    for subkey in ("model", "module"):
        if subkey in state_dict:
            state_dict = state_dict[subkey]
            break
    if len(state_dict) < 20:
        raise RuntimeError(f"pt subkey load failed: {state_dict.keys()}")
    return strip_prefix(state_dict)
def handle_tensors(writer, state_dict, model_arch):
    """
    Convert every tensor of ``state_dict`` and register it with ``writer``.

    For each entry the tensor is turned into a numpy array, a target GGML type
    is chosen (BF16 for bfloat16 sources, F16 otherwise, F32 for selected
    float32/bfloat16 tensors), the data is optionally reshaped to 256 columns
    when ``model_arch.shape_fix`` is set, and the result is quantized and
    passed to ``writer.add_tensor``. Does nothing for an empty state dict.

    Raises ValueError when any tensor name is longer than
    MAX_TENSOR_NAME_LENGTH.
    """
    # (name, length) pairs sorted longest first; the first entry drives both
    # the length check and the column width of the progress output
    name_lengths = tuple(sorted(
        ((key, len(key)) for key in state_dict.keys()),
        key=lambda item: item[1],
        reverse=True,
    ))
    if not name_lengths:
        return
    max_name_len = name_lengths[0][1]
    if max_name_len > MAX_TENSOR_NAME_LENGTH:
        bad_list = ", ".join(f"{key!r} ({namelen})" for key, namelen in name_lengths if namelen > MAX_TENSOR_NAME_LENGTH)
        raise ValueError(f"Can only handle tensor names up to {MAX_TENSOR_NAME_LENGTH} characters. Tensors exceeding the limit: {bad_list}")
    for key, data in tqdm(state_dict.items()):
        # remember the source dtype; `data` is replaced by a numpy array below
        old_dtype = data.dtype
        if any(x in key for x in model_arch.keys_ignore):
            tqdm.write(f"Filtering ignored key: '{key}'")
            continue
        # bfloat16 and float8 are cast to another float type before calling .numpy()
        if data.dtype == torch.bfloat16:
            data = data.to(torch.float32).numpy()
        # this is so we don't break torch 2.0.X
        elif data.dtype in [getattr(torch, "float8_e4m3fn", "_invalid"), getattr(torch, "float8_e5m2", "_invalid")]:
            data = data.to(torch.float16).numpy()
        else:
            data = data.numpy()
        n_dims = len(data.shape)
        data_shape = data.shape
        # default target type: BF16 for bfloat16 sources, F16 for everything else
        if old_dtype == torch.bfloat16:
            data_qtype = gguf.GGMLQuantizationType.BF16
        # elif old_dtype == torch.float32:
        # data_qtype = gguf.GGMLQuantizationType.F32
        else:
            data_qtype = gguf.GGMLQuantizationType.F16
        # The max no. of dimensions that can be handled by the quantization code is 4
        if len(data.shape) > MAX_TENSOR_DIMS:
            model_arch.handle_nd_tensor(key, data)
            continue # needs to be added back later
        # get number of parameters (AKA elements) in this tensor
        n_params = 1
        for dim_size in data_shape:
            n_params *= dim_size
        if old_dtype in (torch.float32, torch.bfloat16):
            if n_dims == 1:
                # one-dimensional tensors should be kept in F32
                # also speeds up inference due to not dequantizing
                data_qtype = gguf.GGMLQuantizationType.F32
            elif n_params <= QUANTIZATION_THRESHOLD:
                # very small tensors
                data_qtype = gguf.GGMLQuantizationType.F32
            elif any(x in key for x in model_arch.keys_hiprec):
                # tensors that require max precision
                data_qtype = gguf.GGMLQuantizationType.F32
        if (model_arch.shape_fix # NEVER reshape for models such as flux
            and n_dims > 1 # Skip one-dimensional tensors
            and n_params >= REARRANGE_THRESHOLD # Only rearrange tensors meeting the size requirement
            and (n_params / 256).is_integer() # Rearranging only makes sense if total elements is divisible by 256
            and not (data.shape[-1] / 256).is_integer() # Only need to rearrange if the last dimension is not divisible by 256
        ):
            # store the original shape as metadata; get_orig_shape reads this key back when loading
            orig_shape = data.shape
            data = data.reshape(n_params // 256, 256)
            writer.add_array(f"comfy.gguf.orig_shape.{key}", tuple(int(dim) for dim in orig_shape))
        try:
            data = gguf.quants.quantize(data, data_qtype)
        except (AttributeError, gguf.QuantError) as e:
            # retry with F16 when the chosen type cannot be produced
            tqdm.write(f"falling back to F16: {e}")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)
        new_name = key # do we need to rename?
        # the shape is printed with its dimensions reversed
        shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
        tqdm.write(f"{f'%-{max_name_len + 4}s' % f'{new_name}'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
        writer.add_tensor(new_name, data, raw_dtype=data_qtype)
def convert_file(path, dst_path=None, interact=True, overwrite=False):
    """
    Convert the checkpoint at ``path`` into a GGUF file.

    Args:
        path: source checkpoint, loaded with ``load_state_dict``.
        dst_path: output file. Defaults to ``<path without extension>-<ftype>.gguf``;
            a literal ``{ftype}`` in the given path is replaced with "BF16" or "F16".
        interact: when the output exists and ``overwrite`` is False, wait for
            the user to confirm on stdin instead of raising.
        overwrite: allow replacing an existing output file without asking.

    Returns:
        ``(dst_path, model_arch)``: the output path and the detected architecture instance.

    Raises:
        OSError: the output exists, ``overwrite`` is False and ``interact`` is False.
    """
    # load & run model detection logic
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    logger.info(f"* Architecture detected from input: {model_arch.arch}")

    # detect & set dtype for output file: the most frequent tensor dtype wins
    dtypes = [x.dtype for x in state_dict.values()]
    dtypes = {x: dtypes.count(x) for x in set(dtypes)}
    main_dtype = max(dtypes, key=dtypes.get)

    if main_dtype == torch.bfloat16:
        ftype_name = "BF16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
    else:
        ftype_name = "F16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16

    if dst_path is None:
        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
    elif "{ftype}" in dst_path:  # lcpp logic
        dst_path = dst_path.replace("{ftype}", ftype_name)

    if os.path.isfile(dst_path) and not overwrite:
        if interact:
            input("Output exists enter to continue or ctrl+c to abort!")
        else:
            raise OSError("Output exists and overwriting is disabled!")

    # handle actual file
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    if ftype_gguf is not None:
        writer.add_file_type(ftype_gguf)

    handle_tensors(writer, state_dict, model_arch)
    writer.write_header_to_file(path=dst_path)
    try:
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file(progress=True)
    finally:
        # release the output file even when writing fails part-way
        writer.close()

    # ModelHyVid.handle_nd_tensor leaves this side file behind for 5D tensors
    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
    if os.path.isfile(fix):
        logger.warning(f"\n### Warning! Fix file found at '{fix}'")
        logger.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")

    return dst_path, model_arch
def is_torch_compatible(tensor):
    """Return True for None and for tensors whose ``tensor_type`` needs no dequantization."""
    if tensor is None:
        return True
    tensor_type = getattr(tensor, "tensor_type", None)
    return tensor_type in TORCH_COMPATIBLE_QTYPES
def is_quantized(tensor):
    """Return True when ``tensor`` is not torch compatible, i.e. needs dequantization."""
    compatible = is_torch_compatible(tensor)
    return not compatible
def dequantize_tensor(tensor, dtype=None, dequant_dtype=None):
    """
    Return ``tensor`` as a regular torch tensor of ``dtype``.

    Torch-compatible tensors are only cast. Types listed in
    ``dequantize_functions`` are dequantized with the torch implementations
    of this module; ``dequant_dtype="target"`` makes that step use ``dtype``.
    Any other type goes through the numpy implementation from ``gguf.quants``.
    """
    qtype = getattr(tensor, "tensor_type", None)
    oshape = getattr(tensor, "tensor_shape", tensor.shape)

    if qtype in TORCH_COMPATIBLE_QTYPES:
        return tensor.to(dtype)

    if qtype not in dequantize_functions:
        # this is incredibly slow
        tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}")
        new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
        return torch.from_numpy(new).to(tensor.device, dtype=dtype)

    if dequant_dtype == "target":
        dequant_dtype = dtype
    result = dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype)
    return result.to(dtype)
def dequantize(data, qtype, oshape, dtype=None):
    """
    Dequantize raw block data of type ``qtype`` and reshape it to ``oshape``.

    The data is viewed as bytes, cut into rows of ``type_size`` bytes (one
    quantization block per row) and handed to the matching entry of
    ``dequantize_functions``.
    """
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    block_fn = dequantize_functions[qtype]

    raw = data.reshape((-1, data.shape[-1])).view(torch.uint8)
    n_blocks = raw.numel() // type_size
    packed = raw.reshape((n_blocks, type_size))

    values = block_fn(packed, block_size, type_size, dtype)
    return values.reshape(oshape)
def to_uint32(x):
    """
    Combine the first four byte columns of ``x`` into one little-endian 32-bit value per row.

    The result is an int32 tensor of shape (rows, 1).
    """
    # no uint32 :(
    widened = x.view(torch.uint8).to(torch.int32)
    b0, b1, b2, b3 = (widened[:, col] for col in range(4))
    packed = b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
    return packed.unsqueeze(1)
def split_block_dims(blocks, *args):
    """
    Split ``blocks`` along dim 1 into chunks of the given widths plus one remainder chunk.

    The last chunk takes whatever columns are left after ``sum(args)``.
    """
    total_width = blocks.shape[1]
    widths = [*args, total_width - sum(args)]
    return torch.split(blocks, widths, dim=1)
# Full weights #
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
    """
    Expand bfloat16 byte pairs to float32.

    Each 16-bit value is moved into the upper half of a 32-bit word and that
    word is reinterpreted as float32. ``dtype`` is not used here; the result
    is always float32.
    """
    halves = blocks.view(torch.int16)
    words = halves.to(torch.int32) << 16
    return words.view(torch.float32)
# Legacy Quants #
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
    """Dequantize Q8_0 blocks: a 2-byte float16 scale followed by int8 values."""
    scale_bytes, quant_bytes = split_block_dims(blocks, 2)
    scale = scale_bytes.view(torch.float16).to(dtype)
    quants = quant_bytes.view(torch.int8)
    return scale * quants
def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q5_1 blocks.

    Each row is split into a 2-byte scale ``d``, a 2-byte offset ``m`` (both
    float16), 4 bytes of high bits ``qh`` and the remaining bytes of packed
    low nibbles ``qs``. The result is ``d * q + m`` with ``q = low | (high << 4)``.
    """
    n_blocks = blocks.shape[0]
    d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
    d = d.view(torch.float16).to(dtype)
    m = m.view(torch.float16).to(dtype)
    # one 32-bit word per block, shifted so that column i carries bit i in its lowest position
    qh = to_uint32(qh)
    qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    # each byte yields two values: its low nibble (shift 0) and its high nibble (shift 4)
    ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape((n_blocks, -1))
    # 5-bit value: 4 low bits from ql, 5th bit from qh
    qs = (ql | (qh << 4))
    return (d * qs) + m
def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q5_0 blocks.

    Each row is split into a 2-byte float16 scale ``d``, 4 bytes of high bits
    ``qh`` and the remaining bytes of packed low nibbles ``qs``. The result is
    ``d * (q - 16)`` with ``q = low | (high << 4)``.
    """
    n_blocks = blocks.shape[0]
    d, qh, qs = split_block_dims(blocks, 2, 4)
    d = d.view(torch.float16).to(dtype)
    # one 32-bit word per block, shifted so that column i carries bit i in its lowest position
    qh = to_uint32(qh)
    qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
    # each byte yields two values: its low nibble (shift 0) and its high nibble (shift 4)
    ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    qh = (qh & 1).to(torch.uint8)
    ql = (ql & 0x0F).reshape(n_blocks, -1)
    # 5-bit value re-centred by subtracting 16
    qs = (ql | (qh << 4)).to(torch.int8) - 16
    return (d * qs)
def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q4_1 blocks.

    Each row holds a 2-byte float16 scale, a 2-byte float16 offset and packed
    nibbles; the result is ``scale * nibble + offset``.
    """
    n_blocks = blocks.shape[0]
    scale_bytes, offset_bytes, packed = split_block_dims(blocks, 2, 2)
    scale = scale_bytes.view(torch.float16).to(dtype)
    offset = offset_bytes.view(torch.float16).to(dtype)

    # shift 0 exposes the low nibble of every byte, shift 4 the high nibble
    nibble_shifts = torch.tensor([0, 4], device=scale.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
    shifted = packed.reshape((n_blocks, -1, 1, block_size // 2)) >> nibble_shifts
    nibbles = (shifted & 0x0F).reshape(n_blocks, -1)
    return (scale * nibbles) + offset
def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q4_0 blocks.

    Each row holds a 2-byte float16 scale followed by packed nibbles; the
    result is ``scale * (nibble - 8)``.
    """
    n_blocks = blocks.shape[0]
    scale_bytes, packed = split_block_dims(blocks, 2)
    scale = scale_bytes.view(torch.float16).to(dtype)

    # shift 0 exposes the low nibble of every byte, shift 4 the high nibble
    nibble_shifts = torch.tensor([0, 4], device=scale.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    shifted = packed.reshape((n_blocks, -1, 1, block_size // 2)) >> nibble_shifts
    centred = (shifted & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
    return scale * centred
# K Quants #
# Number of values each K-quant block is reshaped to by the functions below.
QK_K = 256
# Byte width of the packed scales field split off in the Q5_K and Q4_K decoders.
K_SCALE_SIZE = 12
def get_scale_min(scales):
    """
    Unpack the 12-byte scales field of a K-quant block into 6-bit values.

    The bytes of every block are read as three groups of four. Returns
    ``(sc, mins)``, two uint8 tensors of shape (n_blocks, 8): the first four
    entries of each come from the low 6 bits of the first / second group, the
    last four combine a nibble of the third group with the top two bits of
    the first / second group.
    """
    n_blocks = scales.shape[0]
    groups = scales.view(torch.uint8).reshape((n_blocks, 3, 4))
    d_bytes = groups[:, 0]
    m_bytes = groups[:, 1]
    mixed = groups[:, 2]

    sc_low = d_bytes & 0x3F
    sc_high = (mixed & 0x0F) | ((d_bytes >> 2) & 0x30)
    min_low = m_bytes & 0x3F
    min_high = (mixed >> 4) | ((m_bytes >> 2) & 0x30)

    sc = torch.cat([sc_low, sc_high], dim=-1).reshape((n_blocks, 8))
    mins = torch.cat([min_low, min_high], dim=-1).reshape((n_blocks, 8))
    return (sc, mins)
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q6_K blocks to rows of QK_K values.

    Each row is split into ``ql`` (QK_K // 2 bytes of low nibbles), ``qh``
    (QK_K // 4 bytes of 2-bit high parts), ``scales`` (QK_K // 16 int8 values)
    and the remaining bytes ``d`` (float16). Values are
    ``d * scale * ((low | high << 4) - 32)``.
    """
    n_blocks = blocks.shape[0]
    ql, qh, scales, d, = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
    scales = scales.view(torch.int8).to(dtype)
    d = d.view(torch.float16).to(dtype)
    # per-group scale: block scale times the int8 sub-scale, one per 16 values
    d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
    # low nibble (shift 0) and high nibble (shift 4) of every ql byte
    ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    # four 2-bit fields of every qh byte
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = (qh & 0x03).reshape((n_blocks, -1, 32))
    # 6-bit value re-centred by subtracting 32
    q = (ql | (qh << 4)).to(torch.int8) - 32
    q = q.reshape((n_blocks, QK_K // 16, -1))
    return (d * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q5_K blocks to rows of QK_K values.

    Each row is split into ``d`` and ``dmin`` (2 bytes each, float16), the
    packed ``scales`` (K_SCALE_SIZE bytes), ``qh`` (QK_K // 8 bytes of high
    bits) and the remaining bytes ``qs`` (low nibbles). Values are
    ``d * sc * q - dmin * m`` with ``q = low | (high << 4)`` and ``sc`` / ``m``
    taken from ``get_scale_min``.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)
    sc, m = get_scale_min(scales)
    # per-group scale and offset, broadcast over the values of each group
    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))
    # low nibble (shift 0) and high nibble (shift 4) of every qs byte
    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    # every bit position 0..7 of each qh byte
    qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
    qh = (qh & 0x01).reshape((n_blocks, -1, 32))
    q = (ql | (qh << 4))
    return (d * q - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q4_K blocks to rows of QK_K values.

    Each row is split into ``d`` and ``dmin`` (2 bytes each, float16), the
    packed ``scales`` (K_SCALE_SIZE bytes) and the remaining bytes ``qs``
    (packed nibbles). Values are ``d * sc * nibble - dmin * m`` with ``sc`` /
    ``m`` taken from ``get_scale_min``.
    """
    n_blocks = blocks.shape[0]
    d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)
    sc, m = get_scale_min(scales)
    # per-group scale and offset, broadcast over the values of each group
    d = (d * sc).reshape((n_blocks, -1, 1))
    dm = (dmin * m).reshape((n_blocks, -1, 1))
    # low nibble (shift 0) and high nibble (shift 4) of every qs byte
    qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
    qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
    return (d * qs - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q3_K blocks to rows of QK_K values.

    Each row is split into ``hmask`` (QK_K // 8 bytes), ``qs`` (QK_K // 4
    bytes of 2-bit values), 12 bytes of packed ``scales`` and the remaining
    bytes ``d`` (float16). Sixteen 6-bit scales are rebuilt from the packed
    bytes and offset by -32; values are ``d * scale * q`` where ``q`` is the
    2-bit value minus 4 whenever the matching ``hmask`` bit is clear.
    """
    n_blocks = blocks.shape[0]
    hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
    d = d.view(torch.float16).to(dtype)
    # first 8 bytes carry the low nibbles of the scales, last 4 bytes their 2-bit high parts
    lscales, hscales = scales[:, :8], scales[:, 8:]
    lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
    lscales = lscales.reshape((n_blocks, 16))
    hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 4, 1))
    hscales = hscales.reshape((n_blocks, 16))
    scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
    scales = (scales.to(torch.int8) - 32)
    dl = (d * scales).reshape((n_blocks, 16, 1))
    # four 2-bit fields of every qs byte, and every bit position 0..7 of each hmask byte
    ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
    ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
    # the mask bit is inverted: a cleared bit subtracts 4 below
    qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
    q = (ql.to(torch.int8) - (qh << 2).to(torch.int8))
    return (dl * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
    """
    Dequantize Q2_K blocks.

    Each row is split into ``scales`` (QK_K // 16 bytes), ``qs`` (QK_K // 4
    bytes of 2-bit values), ``d`` (2 bytes, float16) and the remaining bytes
    ``dmin`` (float16). The low nibble of each scale byte multiplies ``d``,
    the high nibble multiplies ``dmin``; values are ``dl * q - ml``.
    """
    n_blocks = blocks.shape[0]
    scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
    d = d.view(torch.float16).to(dtype)
    dmin = dmin.view(torch.float16).to(dtype)
    # (n_blocks, 16, 1)
    dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
    ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
    # four 2-bit fields of every qs byte
    shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
    qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
    qs = qs.reshape((n_blocks, QK_K // 16, 16))
    qs = dl * qs - ml
    return qs.reshape((n_blocks, -1))
# Torch implementations available per quantization type. dequantize_tensor uses this
# table and falls back to the numpy implementation in gguf.quants for types not listed.
dequantize_functions = {
    gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
    gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
    gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
    gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
    gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
    gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
    gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
    gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
    gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
    gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
    gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
def get_orig_shape(reader, tensor_name):
    """
    Read the original shape stored for ``tensor_name``, if any.

    Looks up the ``comfy.gguf.orig_shape.<tensor_name>`` metadata field.
    Returns None when the field is absent, otherwise a ``torch.Size``.

    Raises TypeError when the field is not an ARRAY of INT32.
    """
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None

    # Has original shape metadata, so we try to decode it.
    types = field.types
    well_formed = (
        len(types) == 2
        and types[0] == gguf.GGUFValueType.ARRAY
        and types[1] == gguf.GGUFValueType.INT32
    )
    if not well_formed:
        raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")

    dims = tuple(int(field.parts[part_idx][0]) for part_idx in field.data)
    return torch.Size(dims)
def get_field(reader, field_name, field_type):
    """
    Read a single metadata value from a GGUF reader.

    Args:
        reader: object whose ``get_field(name)`` returns a field or None.
        field_name: metadata key to look up.
        field_type: ``str``, ``int``, ``float`` or ``bool``.

    Returns:
        None when the field is absent, otherwise the value converted to ``field_type``.

    Raises:
        TypeError: ``field_type`` is unsupported, or a string was requested
            but the field is not a single STRING value.
    """
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type == str:
        # extra check here as this is used for checking arch string
        if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
        return str(field.parts[field.data[-1]], encoding="utf-8")
    elif field_type in [int, float, bool]:
        # index the single element explicitly (as get_list_field does) rather
        # than converting a whole array to a scalar
        return field_type(field.parts[field.data[-1]][0])
    else:
        raise TypeError(f"Unknown field type {field_type}")
def get_list_field(reader, field_name, field_type):
    """
    Read a list-valued metadata field from a GGUF reader as a tuple.

    Returns None when the field is absent. ``field_type`` selects the element
    conversion: ``str`` decodes each part as UTF-8, ``int`` / ``float`` /
    ``bool`` convert the first item of each part.

    Raises TypeError for any other ``field_type``.
    """
    field = reader.get_field(field_name)
    if field is None:
        return None

    indices = field.data
    if field_type == str:
        return tuple(str(field.parts[idx], encoding="utf-8") for idx in indices)
    if field_type in (int, float, bool):
        return tuple(field_type(field.parts[idx][0]) for idx in indices)
    raise TypeError(f"Unknown field type {field_type}")
def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False, is_text_model=False) -> dict:
    """
    Read state dict as fake tensors

    Opens ``path`` with ``gguf.GGUFReader`` and wraps every tensor in a
    ``GGMLTensor`` that records the GGML tensor type and the logical shape.

    Args:
        path: GGUF file to read.
        handle_prefix: when at least one tensor name starts with this prefix,
            only prefixed tensors are kept and the prefix is removed from
            their keys. ``None`` disables prefix handling.
        return_arch: also return the architecture string.
        is_text_model: selects which architecture list the file's
            ``general.architecture`` value is checked against.

    Returns:
        The state dict, or ``(state_dict, arch_str)`` when ``return_arch`` is
        set (the ``-> dict`` annotation only covers the first form).

    Raises:
        ValueError: the architecture field is missing or "pig" and either
            ``is_text_model`` is set or key-based detection fails.
    """
    reader = gguf.GGUFReader(path)
    # filter and strip prefix
    has_prefix = False
    if handle_prefix is not None:
        prefix_len = len(handle_prefix)
        tensor_names = set(tensor.name for tensor in reader.tensors)
        has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)
    else:
        prefix_len = 0
    tensors = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))
    # detect and verify architecture
    compat = None
    arch_str = get_field(reader, "general.architecture", str)
    if arch_str in [None, "pig"]:
        if is_text_model:
            raise ValueError(f"This text model is incompatible with llama.cpp!\nConsider using the safetensors version\n({path})")
        compat = "sd.cpp" if arch_str is None else arch_str
        # fall back to key-based detection; any failure is reported as an unsupported model
        try:
            arch_str = detect_arch(set(val[0] for val in tensors)).arch
        except Exception as e:
            raise ValueError(f"This model is not currently supported - ({e})")
    elif arch_str not in TXT_ARCH_LIST and is_text_model:
        logger.warning(f"Unexpected text model architecture type in GGUF file: {arch_str!r}")
    elif arch_str not in IMG_ARCH_LIST and not is_text_model:
        logger.warning(f"Unexpected architecture type in GGUF file: {arch_str!r}")
    if compat:
        logger.warning(f"Warning: This gguf model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")
    # main loading loop
    state_dict = {}
    qtype_dict = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        # torch_tensor = torch.from_numpy(tensor.data) # mmap
        # NOTE: line above replaced with this block to avoid persistent numpy warning about mmap
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
            torch_tensor = torch.from_numpy(tensor.data) # mmap
        # prefer the shape stored in metadata; otherwise use the reader's shape with its dimensions reversed
        shape = get_orig_shape(reader, tensor_name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            # Workaround for stable-diffusion.cpp SDXL detection.
            if compat == "sd.cpp" and arch_str == "sdxl":
                if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]):
                    # drop trailing size-1 dimensions until the shape is 2D
                    while len(shape) > 2 and shape[-1] == 1:
                        shape = shape[:-1]
        # add to state dict
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)
        # keep track of loaded tensor types
        tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1
    # print loaded tensor type counts
    logger.debug("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items()))
    # mark largest tensor for vram estimation
    qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
    if len(qsd) > 0:
        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
        state_dict[max_key].is_largest_weight = True
    if return_arch:
        return (state_dict, arch_str)
    return state_dict
# for remapping llama.cpp -> original key names
# Used by gguf_clip_loader through sd_map_replace, which applies the entries as plain
# substring replacements in insertion order.
T5_SD_MAP = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}
# Substring replacements mapping llama.cpp tensor names to "model.layers.*"-style key names.
# Intended for sd_map_replace, which applies the entries in insertion order.
LLAMA_SD_MAP = {
    "blk.": "model.layers.",
    "attn_norm": "input_layernorm",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "ffn_norm": "post_attention_layernorm",
    "token_embd": "model.embed_tokens",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}
def sd_map_replace(raw_sd, key_map):
    """
    Return a copy of ``raw_sd`` with every key rewritten through ``key_map``.

    Each ``key_map`` entry is applied as a substring replacement, one after
    another in the map's iteration order, so later entries see the result of
    earlier ones. Values are carried over unchanged.
    """
    remapped = {}
    for old_key, value in raw_sd.items():
        new_key = old_key
        for needle, replacement in key_map.items():
            new_key = new_key.replace(needle, replacement)
        remapped[new_key] = value
    return remapped
def llama_permute(raw_sd, n_head, n_head_kv):
    """
    Re-order the rows of the q/k projection tensors in ``raw_sd``.

    Keys ending in ``q_proj.weight`` / ``q_proj.bias`` are permuted with
    ``n_head`` heads, keys ending in ``k_proj.weight`` / ``k_proj.bias`` with
    ``n_head_kv`` heads. The permuted data is assigned to ``.data`` of the
    existing tensor objects, which are then returned in a new dict.
    """
    # Reverse version of LlamaModel.permute in llama.cpp convert script
    def _unpermute(x, heads):
        grouped = x.reshape(heads, x.shape[0] // heads // 2, 2, *x.shape[1:])
        return grouped.swapaxes(1, 2).reshape(x.shape)

    result = {}
    for name, tensor in raw_sd.items():
        if name.endswith(("q_proj.weight", "q_proj.bias")):
            tensor.data = _unpermute(tensor.data, n_head)
        if name.endswith(("k_proj.weight", "k_proj.bias")):
            tensor.data = _unpermute(tensor.data, n_head_kv)
        result[name] = tensor
    return result
def gguf_tokenizer_loader(path, temb_shape):
    """
    Rebuild a sentencepiece model from the tokenizer metadata of a GGUF file.

    Args:
        path: GGUF file to read the ``tokenizer.ggml.*`` fields from.
        temb_shape: shape of the token embedding; only (256384, 4096) is accepted.

    Returns:
        A ``torch.ByteTensor`` holding the serialized sentencepiece model.

    Raises:
        NotImplementedError: the tokenizer model is not "t5" or the embedding
            shape is not the supported one.
    """
    # convert gguf tokenizer to spiece
    logger.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
    spm = model.ModelProto()
    reader = gguf.GGUFReader(path)

    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
        if temb_shape == (256384, 4096):  # probably UMT5
            # assignment, not comparison: `== 1` evaluated to a bool and set nothing
            spm.trainer_spec.model_type = 1  # Unigram (do we have a T5 w/ BPE?)
        else:
            raise NotImplementedError("Unknown model, can't set tokenizer!")
    else:
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    # one sentencepiece entry per (token, score, type) triple, in file order
    for token, score, toktype in zip(tokens, scores, toktypes):
        piece = spm.SentencePiece()
        piece.piece = token
        piece.score = score
        piece.type = toktype
        spm.pieces.append(piece)

    # unsure if any of these are correct
    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.vocab_size = len(tokens)  # split off unused?
    spm.trainer_spec.max_sentence_length = 4096
    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)

    logger.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
    del reader
    return torch.ByteTensor(list(spm.SerializeToString()))
def strip_quant_suffix(name):
    """
    Remove a trailing quantization tag such as "-q4_k_m" or "_iq4_xs" from ``name``.

    Matching is case-insensitive; names without such a suffix are returned unchanged.
    """
    suffix_re = re.compile(r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$", re.IGNORECASE)
    found = suffix_re.search(name)
    if found is None:
        return name
    return name[:found.start()]
def gguf_mmproj_loader(path):
    """
    Locate and load the mmproj GGUF file that belongs to the text encoder at ``path``.

    The directory of ``path`` is searched for ``.gguf`` files whose name
    contains "mmproj" and the text encoder's base name (lower-cased, without
    quantization suffix). Candidates are considered in sorted order, so the
    choice is deterministic when several files match.

    Returns:
        The mmproj state dict with keys renamed through ``CLIP_VISION_SD_MAP``,
        a 5D patch embedding when the file stores it as two tensors, and
        concatenated ``attn.qkv`` tensors when q/k/v are stored separately.
        An empty dict when no matching file is found.
    """
    # Reverse version of Qwen2VLVisionModel.modify_tensors
    logger.info("Attempting to find mmproj file for text encoder...")

    # get name to match w/o quant suffix
    tenc_fname = os.path.basename(path)
    tenc = os.path.splitext(tenc_fname)[0].lower()
    tenc = strip_quant_suffix(tenc)

    # try and find matching mmproj
    # dirname is "" for a bare file name and os.listdir("") raises, so fall back to the cwd
    root = os.path.dirname(path) or "."
    target = []
    for fname in sorted(os.listdir(root)):
        name, ext = os.path.splitext(fname)
        if ext.lower() != ".gguf":
            continue
        lowered = name.lower()
        if "mmproj" not in lowered:
            continue
        if tenc in lowered:
            target.append(fname)

    if len(target) == 0:
        logger.error(f"Error: Can't find mmproj file for '{tenc_fname}' (matching:'{tenc}')! Qwen-Image-Edit will be broken!")
        return {}
    if len(target) > 1:
        logger.error(f"Ambiguous mmproj for text encoder '{tenc_fname}', will use first match.")

    logger.info(f"Using mmproj '{target[0]}' for text encoder '{tenc_fname}'.")
    target = os.path.join(root, target[0])
    vsd = gguf_sd_loader(target, is_text_model=True)

    # concat 4D to 5D
    if "v.patch_embd.weight.1" in vsd:
        w1 = dequantize_tensor(vsd.pop("v.patch_embd.weight"), dtype=torch.float32)
        w2 = dequantize_tensor(vsd.pop("v.patch_embd.weight.1"), dtype=torch.float32)
        vsd["v.patch_embd.weight"] = torch.stack([w1, w2], dim=2)

    # run main replacement
    vsd = sd_map_replace(vsd, CLIP_VISION_SD_MAP)

    # handle split Q/K/V
    if "visual.blocks.0.attn_q.weight" in vsd:
        attns = {}
        # filter out attentions + group
        for k, v in vsd.items():
            if any(x in k for x in ["attn_q", "attn_k", "attn_v"]):
                # "<block>.attn_q.weight" -> group "<block>.attn.qkv.weight", member "q.weight"
                k_attn, k_name = k.rsplit(".attn_", 1)
                k_attn += ".attn.qkv." + k_name.split(".")[-1]
                if k_attn not in attns:
                    attns[k_attn] = {}
                attns[k_attn][k_name] = dequantize_tensor(
                    v, dtype=(torch.bfloat16 if is_quantized(v) else torch.float16)
                )

        # recombine (the separate q/k/v entries stay in the dict)
        for k, v in attns.items():
            suffix = k.split(".")[-1]
            vsd[k] = torch.cat([
                v[f"q.{suffix}"],
                v[f"k.{suffix}"],
                v[f"v.{suffix}"],
            ], dim=0)
        del attns

    return vsd
def gguf_tekken_tokenizer_loader(path, temb_shape):
    """
    Rebuild tekken-style tokenizer data from the tokenizer metadata of a GGUF file.

    Only files whose ``tokenizer.ggml.model`` is "gpt2" and whose token
    embedding shape is (131072, 5120) are accepted; anything else raises
    NotImplementedError. Tokens of type 3 become special tokens, all others
    are decoded to bytes and stored base64-encoded in the vocabulary.

    Returns a ``torch.ByteTensor`` holding the UTF-8 encoded JSON document.
    """
    # convert ggml (hf) tokenizer metadata to tekken/comfy data
    logger.info("Attempting to recreate tekken tokenizer from GGUF file metadata...")
    import json
    import base64
    from transformers.convert_slow_tokenizer import bytes_to_unicode

    reader = gguf.GGUFReader(path)
    model_str = get_field(reader, "tokenizer.ggml.model", str)
    if not (model_str == "gpt2"):
        raise NotImplementedError("Unknown model, can't set tokenizer!")
    if not (temb_shape == (131072, 5120)):  # probably Mistral otherwise
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    vocab = []
    special_tokens = []
    data = {
        "config": {"num_vocab_tokens": 150000, "default_vocab_size": 131072},
        "vocab": vocab,
        "special_tokens": special_tokens,
    }

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    # inverse of the byte -> unicode table, used to turn token text back into bytes
    decoder = {v: k for k, v in bytes_to_unicode().items()}
    for idx, (token, toktype) in enumerate(zip(tokens, toktypes)):
        if toktype == 3:
            special_tokens.append(
                {'rank': idx, 'token_str': token, 'is_control': True}
            )
            continue
        tok = bytes(decoder[char] for char in token)
        vocab.append({
            "rank": len(vocab),
            "token_bytes": base64.b64encode(tok).decode("ascii"),
            "token_str": tok.decode("utf-8", errors="replace")  # ?
        })

    logger.info(f"Created tekken tokenizer with vocab size of {len(data['vocab'])} (+{len(data['special_tokens'])})")
    del reader
    payload = json.dumps(data).encode('utf-8')
    return torch.ByteTensor(list(payload))
def gguf_clip_loader(path):
    """
    Load a text encoder state dict from a GGUF file and remap its keys per architecture.

    - "t5" / "t5encoder": keys are remapped with T5_SD_MAP; a (256384, 4096) token
      embedding additionally gets a rebuilt "spiece_model" entry and is dequantized.
    - "llama" / "qwen2vl": keys are remapped with LLAMA_SD_MAP; token embeddings with
      at least 64 * 1024 rows are dequantized, "llama" is permuted afterwards and
      "qwen2vl" is merged with the result of gguf_mmproj_loader.
    - any other architecture: the state dict is returned as loaded.
    """
    sd, arch = gguf_sd_loader(path, return_arch=True, is_text_model=True)
    temb_key = "token_embd.weight"

    if arch in {"t5", "t5encoder"}:
        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
            # non-standard Comfy-Org tokenizer
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
            logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        return sd_map_replace(sd, T5_SD_MAP)

    if arch not in {"llama", "qwen2vl"}:
        # nothing to remap for other architectures
        return sd

    # TODO: pass model_options["vocab_size"] to loader somehow
    if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
        if arch == "llama" and sd[temb_key].shape == (131072, 5120):
            # non-standard Comfy-Org tokenizer
            sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
        # Same OOM workaround as in the T5 branch.
        logger.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
        sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)

    sd = sd_map_replace(sd, LLAMA_SD_MAP)
    if arch == "llama":
        sd = llama_permute(sd, 32, 8)  # L3 / Mistral
    if arch == "qwen2vl":
        # merge the vision (mmproj) tensors into the text encoder state dict
        sd.update(gguf_mmproj_loader(path))
    return sd
# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
def chained_hasattr(obj, chained_attr):
    """
    Tell whether a dotted attribute path (e.g. "a.b.c") can be resolved starting at obj.

    Returns True only when every attribute along the path exists, False otherwise.
    """
    missing = object()  # unique marker, cannot collide with a real attribute value
    current = obj
    for part in chained_attr.split('.'):
        current = getattr(current, part, missing)
        if current is missing:
            return False
    return True
# A backward and forward compatible way to get `torch.compiler.disable`.
def get_torch_compiler_disable_decorator():
    """
    Pick the decorator factory to use in place of `torch.compiler.disable`.

    Returns a do-nothing factory when `torch.compiler.disable` does not exist, when
    the installed torch is version 2.8 or newer, or when
    `torch._dynamo.config.nontraceable_tensor_subclasses` exists. Otherwise returns
    `torch.compiler.disable` itself.
    """
    from packaging import version

    def passthrough_factory(*args, **kwargs):
        # accepts and ignores any arguments, then hands back the decorated function unchanged
        def identity(fn):
            return fn
        return identity

    if not chained_hasattr(torch, "compiler.disable"):
        logger.debug("ComfyUI-GGUF: Torch too old for torch.compile - bypassing")
        return passthrough_factory  # torch too old

    if version.parse(torch.__version__) >= version.parse("2.8"):
        logger.debug("ComfyUI-GGUF: Allowing full torch compile")
        return passthrough_factory  # torch compile works

    if chained_hasattr(torch, "_dynamo.config.nontraceable_tensor_subclasses"):
        logger.debug("ComfyUI-GGUF: Allowing full torch compile (nightly)")
        return passthrough_factory  # torch compile works, nightly before 2.8 release

    logger.debug("ComfyUI-GGUF: Partial torch compile only, consider updating pytorch")
    return torch.compiler.disable
# Decorator factory chosen once when this module is loaded; applied below as `@torch_compiler_disable()`.
torch_compiler_disable = get_torch_compiler_disable_decorator()
class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights

    Besides the raw tensor data, each instance carries:
      tensor_type  - the type tag supplied by the creator
      tensor_shape - the shape reported by the `shape` property (instead of the storage shape)
      patches      - list of patches attached to this tensor
    """
    def __init__(self, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        # A `[]` default is a single list object shared by every instance created
        # without explicit patches, so mutating one tensor's patches would leak into
        # all the others. Give each instance its own list instead.
        self.patches = [] if patches is None else patches

    def __new__(cls, *args, tensor_type, tensor_shape, patches=None, **kwargs):
        # the metadata keywords are swallowed here so that only the remaining
        # arguments reach torch.Tensor.__new__
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        """Same as torch.Tensor.to, but carries the GGML metadata over to the result."""
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        # copy so the result does not share its patch list with the source tensor
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        # returns this same object, no copy is made
        return self

    def detach(self, *args, **kwargs):
        # returns this same object, no new tensor is created
        return self

    def copy_(self, *args, **kwargs):
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        # best effort: a failing copy is logged and ignored (None is returned in that case)
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            logger.warning(f"ignoring 'copy_' on tensor: {e}")

    def new_empty(self, size, *args, **kwargs):
        # Intel Arc fix, ref#50
        new_tensor = super().new_empty(size, *args, **kwargs)
        return GGMLTensor(
            new_tensor,
            tensor_type=getattr(self, "tensor_type", None),
            tensor_shape=size,
            patches=getattr(self, "patches", []).copy()
        )

    @property
    def shape(self):
        """Return `tensor_shape`, initialising it from `size()` when it was never set."""
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape
class GGMLLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly

    Base for the layer classes in GGMLOps: it swaps in custom state dict loading and
    saving when quantized tensors are involved, and dequantizes (and patches) the
    weight and bias right before the forward computation.
    """
    comfy_cast_weights = True
    # forwarded to dequantize_tensor; the string "target" is special-cased when saving
    dequant_dtype = None
    # None: patches are applied without an explicit dtype; "target": use the dtype
    # requested for the weight; anything else is passed on as the patch dtype
    patch_dtype = None
    # set on load when the weight is flagged `is_largest_weight` (used for vram estimation)
    largest_layer = False
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight=None, bias=None):
        """Return True if the given (or, when omitted, the layer's own) weight or bias is quantized."""
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias")
        # NOTE: using modified load for linear due to not initializing on creation, see GGMLOps todo
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        # Not strictly required, but fixes embedding shape mismatch. Threshold set in loader.py
        if isinstance(self, torch.nn.Embedding) and self.weight.shape[0] >= (64 * 1024):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """
        Adopt `{prefix}weight` / `{prefix}bias` from state_dict as non-trainable parameters.

        Other keys under `prefix` are appended to unexpected_keys; a Linear layer
        left without a weight gets a zero placeholder and is reported in missing_keys.
        """
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            # Only keys that really live under this module's prefix are considered.
            # Slicing alone would also match a key under a different prefix of the
            # same length (e.g. "b.weight" while loading "a.").
            if not k.startswith(prefix):
                continue
            name = k[prefix_len:]
            if name == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif name == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                unexpected_keys.append(k)

        # For Linear layer with missing weight
        if self.weight is None and isinstance(self, torch.nn.Linear):
            # torch.nn.Linear stores its weight as (out_features, in_features)
            v = torch.zeros(self.out_features, self.in_features)
            self.weight = torch.nn.Parameter(v, requires_grad=False)
            missing_keys.append(prefix + "weight")

        # for vram estimation (TODO: less fragile logic?)
        if getattr(self.weight, "is_largest_weight", False):
            self.largest_layer = True

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)

    def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
        # This is a fake state dict for vram estimation
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias

        # Take into account space required for dequantizing the largest tensor
        if self.largest_layer:
            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
            dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16
            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
            destination[prefix + "temp.weight"] = temp

    def get_weight(self, tensor, dtype):
        """Dequantize tensor and apply its attached patches; returns None when tensor is None."""
        if tensor is None:
            return

        # consolidate and load patches to GPU in async
        patch_list = []
        device = tensor.device
        for patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # prevent propagating custom tensor class
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)

        # apply patches (`key` is the one from the last (patches, key) pair iterated above)
        if len(patch_list) > 0:
            if self.patch_dtype is None:
                weight = calculate_weight(patch_list, weight, key)
            else:
                # for testing, may degrade image quality
                patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype
                weight = calculate_weight(patch_list, weight, key, patch_dtype)
        return weight

    @torch_compiler_disable()
    def cast_bias_weight(self, input=None, dtype=None, device=None, bias_dtype=None):
        """
        Return `(weight, bias)` dequantized through get_weight and cast with cast_to.

        When input is given, any of dtype / bias_dtype / device left as None is
        filled in from it (dtype falls back to torch.float32 if input has no
        `dtype` attribute). bias is None when the layer has no bias.
        """
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = device_supports_non_blocking(device)
        if self.bias is not None:
            bias = self.get_weight(self.bias.to(device), dtype)
            bias = cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)

        weight = self.get_weight(self.weight.to(device), dtype)
        weight = cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
        return weight, bias

    def forward_comfy_cast_weights(self, input, *args, **kwargs):
        if self.is_ggml_quantized():
            out = self.forward_ggml_cast_weights(input, *args, **kwargs)
        else:
            # this is from the mixin
            out = super().forward_comfy_cast_weights(input, *args, **kwargs)  # pylint: disable=no-member

        # non-ggml forward might still propagate custom tensor class
        if isinstance(out, GGMLTensor):
            out = torch.Tensor(out)
        return out

    def forward_ggml_cast_weights(self, input):
        # implemented by the concrete layer classes
        raise NotImplementedError
class GGMLOps(manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """

    class Linear(GGMLLayer, manual_cast.Linear):
        dequant_dtype = None
        patch_dtype = None

        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            # TODO: better workaround for reserved memory spike on windows
            # Issue is with `torch.empty` still reserving the full memory for the layer
            # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
            # Only the bare Module init runs; weight and bias start out as None,
            # and `bias`, `device` and `dtype` are accepted but not used here.
            torch.nn.Module.__init__(self)
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None

        def forward_ggml_cast_weights(self, input):
            cast_weight, cast_bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, cast_weight, cast_bias)

    class Conv2d(GGMLLayer, manual_cast.Conv2d):
        def forward_ggml_cast_weights(self, input):
            cast_weight, cast_bias = self.cast_bias_weight(input)
            return self._conv_forward(input, cast_weight, cast_bias)

    class Embedding(GGMLLayer, manual_cast.Embedding):
        def forward_ggml_cast_weights(self, input, out_dtype=None):
            # remember the caller's request; the result is always converted to it at the end
            requested_dtype = out_dtype
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                out_dtype = None
            # NOTE(review): the module itself is passed as `input` here, so with
            # out_dtype=None the dtype is looked up on the module rather than on a
            # tensor - confirm this is intended before changing it.
            table, _unused_bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype)
            looked_up = torch.nn.functional.embedding(
                input, table, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            return looked_up.to(dtype=requested_dtype)

    class LayerNorm(GGMLLayer, manual_cast.LayerNorm):
        def forward_ggml_cast_weights(self, input):
            # without a weight there is nothing to dequantize here; defer to the parent path
            if self.weight is None:
                return super().forward_comfy_cast_weights(input)
            cast_weight, cast_bias = self.cast_bias_weight(input)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, cast_weight, cast_bias, self.eps)

    class GroupNorm(GGMLLayer, manual_cast.GroupNorm):
        def forward_ggml_cast_weights(self, input):
            cast_weight, cast_bias = self.cast_bias_weight(input)
            return torch.nn.functional.group_norm(input, self.num_groups, cast_weight, cast_bias, self.eps)
def move_patch_to_device(item, device):
    """
    Recursively move every tensor inside item to device (non-blocking).

    Tuples come back as tuples and lists as lists, with their elements processed
    the same way; any other value is returned unchanged.
    """
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    if isinstance(item, (tuple, list)):
        moved = [move_patch_to_device(child, device) for child in item]
        return tuple(moved) if isinstance(item, tuple) else moved
    return item