diff --git a/cfz_checkpoint_loader.py b/cfz_checkpoint_loader.py
new file mode 100644
index 000000000..3f8097d83
--- /dev/null
+++ b/cfz_checkpoint_loader.py
@@ -0,0 +1,141 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.sd import load_checkpoint_guess_config
+import folder_paths
+
+# ------------------------ Core Quantization Logic -------------------------
+def make_quantized_forward(quant_dtype="float32"):
+    def forward(self, x):
+        dtype = torch.float32 if quant_dtype == "float32" else torch.float16
+
+        # Dequantize the stored integer weight on the fly
+        W = self.int8_weight.to(x.device, dtype=dtype)
+        scale = self.scale.to(x.device, dtype=dtype)
+        if hasattr(self, 'zero_point') and self.zero_point is not None:
+            zp = self.zero_point.to(x.device, dtype=dtype)
+            W = W.sub(zp).mul(scale)
+        else:
+            W = W.mul(scale)
+
+        bias = self.bias.to(dtype) if self.bias is not None else None
+        x = x.to(dtype)
+
+        if isinstance(self, nn.Linear):
+            out = F.linear(x, W, bias)
+        elif isinstance(self, nn.Conv2d):
+            out = F.conv2d(x, W, bias,
+                           self.stride, self.padding,
+                           self.dilation, self.groups)
+        else:
+            return x
+
+        # LoRA application (if present): add the low-rank delta to the layer output
+        if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
+            out = out + self.lora_up(self.lora_down(x)) * self.lora_alpha
+
+        return out
+    return forward
+
+
+def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
+    # Per-output-channel quantization; only 8-bit (num_bits=8) is implemented.
+    reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]
+    if use_asymmetric:
+        min_val = weight.amin(dim=reduce_dim, keepdim=True)
+        max_val = weight.amax(dim=reduce_dim, keepdim=True)
+        scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
+        zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
+        qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
+    else:
+        w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
+        scale = torch.clamp(w_max / 127.0, min=1e-8)
+        qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
+        zero_point = None
+
+    return qweight, scale.to(torch.float16), zero_point
+
+
+def apply_quantization(model, use_asymmetric=False, quant_dtype="float32"):
+    quant_count = 0
+
+    def _quantize_module(module, prefix=""):
+        nonlocal quant_count
+        for name, child in module.named_children():
+            full_name = f"{prefix}.{name}" if prefix else name
+
+            if isinstance(child, (nn.Linear, nn.Conv2d)):
+                try:
+                    W = child.weight.data.float()
+                    qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
+
+                    # Replace the float weight parameter with quantized buffers
+                    del child._parameters["weight"]
+                    child.register_buffer("int8_weight", qW)
+                    child.register_buffer("scale", scale)
+                    if zp is not None:
+                        child.register_buffer("zero_point", zp)
+                    else:
+                        child.zero_point = None
+
+                    child.forward = make_quantized_forward(quant_dtype).__get__(child)
+                    quant_count += 1
+                except Exception as e:
+                    print(f"Failed to quantize {full_name}: {e}")
+
+            _quantize_module(child, full_name)
+
+    _quantize_module(model)
+    print(f"✅ Successfully quantized {quant_count} layers")
+    return model
+
+# ---------------------- ComfyUI Node Implementation ------------------------
+class CheckpointLoaderQuantized:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+                "enable_quant": ("BOOLEAN", {"default": True}),
+                "use_asymmetric": ("BOOLEAN", {"default": False}),
+                "quant_dtype": (["float32", "float16"], {"default": "float32"}),  # dequantization precision
+            }
+        }
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_quantized"
+    CATEGORY = "Loaders (Quantized)"
+    OUTPUT_NODE = False
+
+    def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype):
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        if ckpt_path is None or not os.path.exists(ckpt_path):
+            raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")
+
+        model_patcher, clip, vae, _ = load_checkpoint_guess_config(
+            ckpt_path,
+            output_vae=True,
+            output_clip=True,
+            embedding_directory=folder_paths.get_folder_paths("embeddings")
+        )
+
+        if enable_quant:
+            mode = "Asymmetric" if use_asymmetric else "Symmetric"
+            print(f"🔧 Applying {mode} 8-bit quantization to {ckpt_name} (dtype={quant_dtype})")
+            apply_quantization(model_patcher.model, use_asymmetric=use_asymmetric, quant_dtype=quant_dtype)
+        else:
+            print(f"🔧 Loading {ckpt_name} without quantization")
+
+        return (model_patcher, clip, vae)
+
+# ------------------------- Node Registration -------------------------------
+NODE_CLASS_MAPPINGS = {
+    "CheckpointLoaderQuantized": CheckpointLoaderQuantized,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CheckpointLoaderQuantized": "CFZ Checkpoint Loader",
+}
+
+__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']