mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-09 05:40:49 +08:00
140 lines
5.3 KiB
Python
140 lines
5.3 KiB
Python
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from comfy.sd import load_checkpoint_guess_config, load_checkpoint
|
|
from comfy.model_patcher import ModelPatcher
|
|
import folder_paths
|
|
|
|
# ------------------------ Core Quantization Logic -------------------------
|
|
def make_quantized_forward(quant_dtype="float32"):
|
|
def forward(self, x):
|
|
dtype = torch.float32 if quant_dtype == "float32" else torch.float16
|
|
|
|
W = self.int8_weight.to(x.device, dtype=dtype)
|
|
if hasattr(self, 'zero_point') and self.zero_point is not None:
|
|
zp = self.zero_point.to(x.device, dtype=dtype)
|
|
W = W.sub(zp).mul(self.scale)
|
|
else:
|
|
W = W.mul(self.scale)
|
|
|
|
bias = self.bias.to(dtype) if self.bias is not None else None
|
|
|
|
# LoRA application (if present)
|
|
if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
|
|
x = x + self.lora_up(self.lora_down(x)) * self.lora_alpha
|
|
|
|
x = x.to(dtype)
|
|
|
|
if isinstance(self, nn.Linear):
|
|
return F.linear(x, W, bias)
|
|
elif isinstance(self, nn.Conv2d):
|
|
return F.conv2d(x, W, bias,
|
|
self.stride, self.padding,
|
|
self.dilation, self.groups)
|
|
else:
|
|
return x
|
|
return forward
|
|
|
|
|
|
|
|
def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
|
|
reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]
|
|
if use_asymmetric:
|
|
min_val = weight.amin(dim=reduce_dim, keepdim=True)
|
|
max_val = weight.amax(dim=reduce_dim, keepdim=True)
|
|
scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
|
|
zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
|
|
qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
|
|
else:
|
|
w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
|
|
scale = torch.clamp(w_max / 127.0, min=1e-8)
|
|
qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
|
|
zero_point = None
|
|
|
|
return qweight, scale.to(torch.float16), zero_point
|
|
|
|
|
|
def apply_quantization(model, use_asymmetric=False, quant_dtype="float32"):
|
|
quant_count = 0
|
|
|
|
def _quantize_module(module, prefix=""):
|
|
nonlocal quant_count
|
|
for name, child in module.named_children():
|
|
full_name = f"{prefix}.{name}" if prefix else name
|
|
|
|
if isinstance(child, (nn.Linear, nn.Conv2d)):
|
|
try:
|
|
W = child.weight.data.float()
|
|
qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
|
|
|
|
del child._parameters["weight"]
|
|
child.register_buffer("int8_weight", qW)
|
|
child.register_buffer("scale", scale)
|
|
if zp is not None:
|
|
child.register_buffer("zero_point", zp)
|
|
else:
|
|
child.zero_point = None
|
|
|
|
child.forward = make_quantized_forward(quant_dtype).__get__(child)
|
|
quant_count += 1
|
|
except Exception as e:
|
|
print(f"Failed to quantize {full_name}: {str(e)}")
|
|
|
|
_quantize_module(child, full_name)
|
|
|
|
_quantize_module(model)
|
|
print(f"✅ Successfully quantized {quant_count} layers")
|
|
return model
|
|
|
|
# ---------------------- ComfyUI Node Implementation ------------------------
|
|
class CheckpointLoaderQuantized:
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
|
"enable_quant": ("BOOLEAN", {"default": True}),
|
|
"use_asymmetric": ("BOOLEAN", {"default": False}),
|
|
"quant_dtype": (["float32", "float16"], {"default": "float32"}), # Toggle for precision
|
|
}
|
|
}
|
|
|
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
|
FUNCTION = "load_quantized"
|
|
CATEGORY = "Loaders (Quantized)"
|
|
OUTPUT_NODE = False
|
|
|
|
def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype):
|
|
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
|
|
|
if not os.path.exists(ckpt_path):
|
|
raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")
|
|
|
|
model_patcher, clip, vae, _ = load_checkpoint_guess_config(
|
|
ckpt_path,
|
|
output_vae=True,
|
|
output_clip=True,
|
|
embedding_directory=folder_paths.get_folder_paths("embeddings")
|
|
)
|
|
|
|
if enable_quant:
|
|
mode = "Asymmetric" if use_asymmetric else "Symmetric"
|
|
print(f"🔧 Applying {mode} 8-bit quantization to {ckpt_name} (dtype={quant_dtype})")
|
|
apply_quantization(model_patcher.model, use_asymmetric=use_asymmetric, quant_dtype=quant_dtype)
|
|
else:
|
|
print(f"🔧 Loading {ckpt_name} without quantization")
|
|
|
|
return (model_patcher, clip, vae)
|
|
|
|
# ------------------------- Node Registration -------------------------------
|
|
NODE_CLASS_MAPPINGS = {
|
|
"CheckpointLoaderQuantized": CheckpointLoaderQuantized,
|
|
}
|
|
|
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
"CheckpointLoaderQuantized": "CFZ Checkpoint Loader",
|
|
}
|
|
|
|
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|