Commit a8ddea2271 (parent 5b5165371e): Add files via upload
New file: custom_nodes/cfz_checkpoint_loader.py (+139 lines)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from comfy.sd import load_checkpoint_guess_config
import folder_paths

# ------------------------ Core Quantization Logic -------------------------
# Factory that returns a replacement forward() which dequantizes the stored
# 8-bit weight on the fly before running the usual Linear/Conv2d math.
def make_quantized_forward(quant_dtype="float32"):
    def forward(self, x):
        dtype = torch.float32 if quant_dtype == "float32" else torch.float16

        # Dequantize: W = (q - zero_point) * scale (asymmetric) or q * scale (symmetric).
        W = self.int8_weight.to(x.device, dtype=dtype)
        scale = self.scale.to(x.device, dtype=dtype)
        if hasattr(self, "zero_point") and self.zero_point is not None:
            zp = self.zero_point.to(x.device, dtype=dtype)
            W = W.sub(zp).mul(scale)
        else:
            W = W.mul(scale)

        bias = self.bias.to(x.device, dtype=dtype) if self.bias is not None else None

        # LoRA application (if present): the adapter output is added to the input
        # activation before the dequantized weight is applied.
        if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
            x = x + self.lora_up(self.lora_down(x)) * self.lora_alpha

        x = x.to(dtype)

        if isinstance(self, nn.Linear):
            return F.linear(x, W, bias)
        elif isinstance(self, nn.Conv2d):
            return F.conv2d(x, W, bias,
                            self.stride, self.padding,
                            self.dilation, self.groups)
        else:
            # Unsupported module type: pass the input through unchanged.
            return x

    return forward


# Per-output-channel 8-bit quantization. num_bits is currently unused; the
# 127/255 constants below hard-code 8 bits.
def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
    # Reduce over every dimension except dim 0 (the output channels).
    reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]
    if use_asymmetric:
        min_val = weight.amin(dim=reduce_dim, keepdim=True)
        max_val = weight.amax(dim=reduce_dim, keepdim=True)
        scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
        zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
        qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
    else:
        w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
        scale = torch.clamp(w_max / 127.0, min=1e-8)
        qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
        zero_point = None

    return qweight, scale.to(torch.float16), zero_point
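# Illustrative sketch only (not used by the node): round-trip a random weight
# through quantize_weight and the same dequantization make_quantized_forward
# performs, to gauge the worst-case per-channel 8-bit error. The helper name
# below is hypothetical and can be removed without affecting the loader.
def _roundtrip_error_demo(use_asymmetric=False):
    w = torch.randn(16, 32)  # e.g. a small Linear weight
    q, scale, zp = quantize_weight(w, use_asymmetric=use_asymmetric)
    deq = q.float()
    if zp is not None:
        deq = deq - zp.float()
    deq = deq * scale.float()  # mirrors the dequantization in the patched forward
    return (w - deq).abs().max().item()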
def apply_quantization(model, use_asymmetric=False, quant_dtype="float32"):
    quant_count = 0

    def _quantize_module(module, prefix=""):
        nonlocal quant_count
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name

            if isinstance(child, (nn.Linear, nn.Conv2d)):
                try:
                    W = child.weight.data.float()
                    qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)

                    # Drop the float weight and keep only the int8 data plus scale/zero-point.
                    del child._parameters["weight"]
                    child.register_buffer("int8_weight", qW)
                    child.register_buffer("scale", scale)
                    if zp is not None:
                        child.register_buffer("zero_point", zp)
                    else:
                        child.zero_point = None

                    # Bind the dequantizing forward to this module instance.
                    child.forward = make_quantized_forward(quant_dtype).__get__(child)
                    quant_count += 1
                except Exception as e:
                    print(f"Failed to quantize {full_name}: {e}")

            _quantize_module(child, full_name)

    _quantize_module(model)
    print(f"✅ Successfully quantized {quant_count} layers")
    return model
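# Illustrative sketch only (an assumption, not part of the ComfyUI node): applying
# apply_quantization to a toy model shows that Linear layers lose their float
# "weight" parameter, gain int8/uint8 buffers, and still produce outputs of the
# expected shape. The helper name is hypothetical.
def _apply_quantization_demo():
    toy = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    apply_quantization(toy, use_asymmetric=True, quant_dtype="float32")
    assert "weight" not in dict(toy[0].named_parameters())
    assert toy[0].int8_weight.dtype == torch.uint8  # asymmetric path stores uint8
    out = toy(torch.randn(2, 8))
    return out.shape  # expected: torch.Size([2, 4])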
# ---------------------- ComfyUI Node Implementation ------------------------
class CheckpointLoaderQuantized:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "enable_quant": ("BOOLEAN", {"default": True}),
                "use_asymmetric": ("BOOLEAN", {"default": False}),
                "quant_dtype": (["float32", "float16"], {"default": "float32"}),  # toggle for compute precision
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_quantized"
    CATEGORY = "Loaders (Quantized)"
    OUTPUT_NODE = False

    def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)

        if ckpt_path is None or not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")

        model_patcher, clip, vae, _ = load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings")
        )

        if enable_quant:
            mode = "Asymmetric" if use_asymmetric else "Symmetric"
            print(f"🔧 Applying {mode} 8-bit quantization to {ckpt_name} (dtype={quant_dtype})")
            apply_quantization(model_patcher.model, use_asymmetric=use_asymmetric, quant_dtype=quant_dtype)
        else:
            print(f"🔧 Loading {ckpt_name} without quantization")

        return (model_patcher, clip, vae)

# ------------------------- Node Registration -------------------------------
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderQuantized": CheckpointLoaderQuantized,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderQuantized": "CFZ Checkpoint Loader",
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
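# Illustrative sketch only: ComfyUI imports this file from custom_nodes/ and reads
# NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS to register the node. The helper
# below is hypothetical and merely prints what gets exposed to the UI.
def _list_registered_nodes():
    for key, cls in NODE_CLASS_MAPPINGS.items():
        label = NODE_DISPLAY_NAME_MAPPINGS.get(key, key)
        print(f"{label}: {cls.__name__} (category: {cls.CATEGORY})")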