ComfyUI/cfz_checkpoint_loader.py
2025-05-28 02:14:31 +03:00

140 lines
5.3 KiB
Python

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.sd import load_checkpoint_guess_config, load_checkpoint
from comfy.model_patcher import ModelPatcher
import folder_paths
# ------------------------ Core Quantization Logic -------------------------
def make_quantized_forward(quant_dtype="float32"):
    """Build a replacement ``forward`` for a layer that stores quantized weights.

    The returned function is meant to be bound to an ``nn.Linear`` or
    ``nn.Conv2d`` instance carrying ``int8_weight`` and ``scale`` buffers and,
    optionally, a ``zero_point`` buffer.  On every call the weight is
    dequantized as ``(int8_weight - zero_point) * scale`` (the zero point is
    skipped when absent or ``None``) and the regular linear / convolution op
    is executed in the compute dtype.

    Args:
        quant_dtype: ``"float32"`` selects ``torch.float32`` as the compute
            dtype; any other value selects ``torch.float16``.

    Returns:
        A plain function ``forward(self, x)``; bind it to a layer with
        ``forward.__get__(layer)``.  Its result is in the compute dtype, not
        necessarily in the dtype of ``x``.
    """
    # The compute dtype never changes between calls, so resolve it only once.
    dtype = torch.float32 if quant_dtype == "float32" else torch.float16

    def forward(self, x):
        device = x.device

        # Dequantize on the device of the input; every operand is moved there
        # so a layer whose buffers live elsewhere does not hit a device mismatch.
        W = self.int8_weight.to(device, dtype=dtype)
        scale = self.scale.to(device, dtype=dtype)
        zero_point = getattr(self, "zero_point", None)
        if zero_point is not None:
            W = W.sub(zero_point.to(device, dtype=dtype)).mul(scale)
        else:
            W = W.mul(scale)

        bias = self.bias.to(device, dtype=dtype) if self.bias is not None else None

        # LoRA (if present) is a low-rank correction of the layer *output*:
        # y = W x + alpha * up(down(x)).  It is evaluated on the untouched
        # input and added after the main op; adding it to the input would be
        # shape-incompatible whenever in_features != out_features.
        lora_delta = None
        if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
            lora_delta = self.lora_up(self.lora_down(x)) * self.lora_alpha

        x = x.to(dtype)
        if isinstance(self, nn.Linear):
            out = F.linear(x, W, bias)
        elif isinstance(self, nn.Conv2d):
            out = F.conv2d(x, W, bias,
                           self.stride, self.padding,
                           self.dilation, self.groups)
        else:
            # Unsupported layer type: behave as identity on the cast input.
            out = x

        if lora_delta is not None:
            out = out + lora_delta.to(device, dtype=dtype)
        return out

    return forward
def quantize_weight(weight: torch.Tensor, num_bits: int = 8, use_asymmetric: bool = False):
    """Quantize a weight tensor with one scale per output channel (dim 0).

    Statistics are reduced over every dimension except the first, so ``scale``
    (and ``zero_point``) have shape ``(out, 1, ...)`` and broadcast against the
    weight.  The real value is recovered as ``(q - zero_point) * scale`` in the
    asymmetric case and ``q * scale`` in the symmetric case.

    Args:
        weight: Floating-point weight tensor; dim 0 is the channel dimension.
        num_bits: Width of the integer grid, from 2 to 8.  Values are always
            stored in an 8-bit tensor, which is why wider grids are rejected.
        use_asymmetric: When true, use an unsigned grid with a per-channel zero
            point; otherwise use a signed grid centred on zero.

    Returns:
        ``(qweight, scale, zero_point)`` where ``qweight`` is ``torch.uint8``
        (asymmetric) or ``torch.int8`` (symmetric), ``scale`` is
        ``torch.float16`` and ``zero_point`` is ``torch.uint8`` or ``None``.

    Raises:
        ValueError: If ``num_bits`` is outside ``2..8``.
    """
    if not 2 <= num_bits <= 8:
        raise ValueError(f"num_bits must be between 2 and 8, got {num_bits}")

    # Everything but the leading (output-channel) dimension is reduced.
    reduce_dim = tuple(range(1, weight.ndim))

    if use_asymmetric:
        qmax = 2 ** num_bits - 1
        # Widen the range so that it always contains 0.  Without this a channel
        # whose weights are all positive (or all negative) gets a zero point
        # outside [0, qmax]; clamping it then shifts the grid and saturates
        # part of the channel.
        min_val = weight.amin(dim=reduce_dim, keepdim=True).clamp(max=0)
        max_val = weight.amax(dim=reduce_dim, keepdim=True).clamp(min=0)
        # The lower bound keeps the division below finite for constant channels.
        scale = torch.clamp((max_val - min_val) / qmax, min=1e-8)
        zero_point = torch.clamp((-min_val / scale).round(), 0, qmax).to(torch.uint8)
        qweight = torch.clamp((weight / scale + zero_point).round(), 0, qmax).to(torch.uint8)
    else:
        qmax = 2 ** (num_bits - 1) - 1
        qmin = -qmax - 1
        w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
        scale = torch.clamp(w_max / qmax, min=1e-8)
        qweight = torch.clamp((weight / scale).round(), qmin, qmax).to(torch.int8)
        zero_point = None

    return qweight, scale.to(torch.float16), zero_point
def apply_quantization(model, use_asymmetric=False, quant_dtype="float32"):
    """Quantize the weight of every ``nn.Linear`` / ``nn.Conv2d`` in *model*, in place.

    The module tree is walked recursively.  For each matching layer the
    ``weight`` parameter is replaced by ``int8_weight`` and ``scale`` buffers
    (plus a ``zero_point`` buffer in asymmetric mode, or a ``None`` attribute
    otherwise) and the instance's ``forward`` is swapped for the dequantizing
    forward produced by ``make_quantized_forward``.

    Quantization is best effort: a layer that raises is reported on stdout and
    left as it was, and the walk continues.

    Args:
        model: Root ``nn.Module``; only its descendants are considered, never
            the root itself.
        use_asymmetric: Forwarded to ``quantize_weight``.
        quant_dtype: Forwarded to ``make_quantized_forward``.

    Returns:
        The same ``model`` object, modified in place.
    """
    # The forward function does not depend on the layer, so build it once and
    # bind it to each layer instead of creating a new closure per layer.
    quantized_forward = make_quantized_forward(quant_dtype)
    quant_count = 0

    def _quantize_module(module, prefix=""):
        nonlocal quant_count
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            if isinstance(child, (nn.Linear, nn.Conv2d)):
                try:
                    W = child.weight.data.float()
                    qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
                    # Install the replacement state first and drop the original
                    # weight last: if anything above or below raises, the layer
                    # still owns its weight and its original forward.
                    child.register_buffer("int8_weight", qW)
                    child.register_buffer("scale", scale)
                    if zp is not None:
                        child.register_buffer("zero_point", zp)
                    else:
                        child.zero_point = None
                    del child._parameters["weight"]
                    child.forward = quantized_forward.__get__(child)
                    quant_count += 1
                except Exception as e:
                    # Deliberately broad: one unsupported layer must not abort
                    # quantization of the rest of the model.
                    print(f"Failed to quantize {full_name}: {e}")
            _quantize_module(child, full_name)

    _quantize_module(model)
    print(f"✅ Successfully quantized {quant_count} layers")
    return model
# ---------------------- ComfyUI Node Implementation ------------------------
class CheckpointLoaderQuantized:
    """ComfyUI node that loads a checkpoint and optionally quantizes its model.

    The node returns the same ``(MODEL, CLIP, VAE)`` triple as a regular
    checkpoint loader; when quantization is enabled the loaded model's
    ``nn.Linear`` / ``nn.Conv2d`` weights are replaced in place by 8-bit
    buffers via ``apply_quantization``.  CLIP and VAE are returned untouched.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs: checkpoint name plus quantization options."""
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "enable_quant": ("BOOLEAN", {"default": True}),
                "use_asymmetric": ("BOOLEAN", {"default": False}),
                # Compute dtype used when dequantizing weights at inference time.
                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_quantized"
    CATEGORY = "Loaders (Quantized)"
    OUTPUT_NODE = False

    def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype):
        """Load ``ckpt_name`` and, if requested, quantize its model in place.

        Args:
            ckpt_name: File name inside the ``checkpoints`` folder.
            enable_quant: When false the checkpoint is returned unmodified.
            use_asymmetric: Use asymmetric (zero-point) instead of symmetric
                quantization.
            quant_dtype: ``"float32"`` or ``"float16"``; forwarded to
                ``apply_quantization``.

        Returns:
            Tuple ``(model_patcher, clip, vae)``.

        Raises:
            FileNotFoundError: If the checkpoint cannot be resolved to an
                existing file.
        """
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        # The None check guards against a lookup that resolves to nothing
        # (ComfyUI's folder_paths is expected to return None in that case);
        # os.path.exists(None) would raise TypeError instead of the intended
        # FileNotFoundError.
        if ckpt_path is None or not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")

        loaded = load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings"),
        )
        # Only the first three results are used; slicing keeps this working
        # regardless of how many trailing values the loader returns.
        model_patcher, clip, vae = loaded[:3]

        if enable_quant:
            mode = "Asymmetric" if use_asymmetric else "Symmetric"
            print(f"🔧 Applying {mode} 8-bit quantization to {ckpt_name} (dtype={quant_dtype})")
            apply_quantization(model_patcher.model, use_asymmetric=use_asymmetric, quant_dtype=quant_dtype)
        else:
            print(f"🔧 Loading {ckpt_name} without quantization")

        return (model_patcher, clip, vae)
# ------------------------- Node Registration -------------------------------
# Key shared by both registration tables below.
_NODE_ID = "CheckpointLoaderQuantized"

# Node identifier -> implementing class.
NODE_CLASS_MAPPINGS = {_NODE_ID: CheckpointLoaderQuantized}

# Node identifier -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = {_NODE_ID: "CFZ Checkpoint Loader"}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]