diff --git a/custom_nodes/cfz_checkpoint_loader.py b/custom_nodes/cfz_checkpoint_loader.py
deleted file mode 100644
index 3f8097d83..000000000
--- a/custom_nodes/cfz_checkpoint_loader.py
+++ /dev/null
@@ -1,139 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from comfy.sd import load_checkpoint_guess_config, load_checkpoint
-from comfy.model_patcher import ModelPatcher
-import folder_paths
-
-# ------------------------ Core Quantization Logic -------------------------
-def make_quantized_forward(quant_dtype="float32"):
-    def forward(self, x):
-        dtype = torch.float32 if quant_dtype == "float32" else torch.float16
-
-        W = self.int8_weight.to(x.device, dtype=dtype)
-        if hasattr(self, 'zero_point') and self.zero_point is not None:
-            zp = self.zero_point.to(x.device, dtype=dtype)
-            W = W.sub(zp).mul(self.scale)
-        else:
-            W = W.mul(self.scale)
-
-        bias = self.bias.to(dtype) if self.bias is not None else None
-
-        # LoRA application (if present)
-        if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
-            x = x + self.lora_up(self.lora_down(x)) * self.lora_alpha
-
-        x = x.to(dtype)
-
-        if isinstance(self, nn.Linear):
-            return F.linear(x, W, bias)
-        elif isinstance(self, nn.Conv2d):
-            return F.conv2d(x, W, bias,
-                            self.stride, self.padding,
-                            self.dilation, self.groups)
-        else:
-            return x
-    return forward
-
-
-
-def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
-    reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]
-    if use_asymmetric:
-        min_val = weight.amin(dim=reduce_dim, keepdim=True)
-        max_val = weight.amax(dim=reduce_dim, keepdim=True)
-        scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
-        zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
-        qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
-    else:
-        w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
-        scale = torch.clamp(w_max / 127.0, min=1e-8)
-        qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
-        zero_point = None
-
-    return qweight, scale.to(torch.float16), zero_point
-
-
-def apply_quantization(model, use_asymmetric=False, quant_dtype="float32"):
-    quant_count = 0
-
-    def _quantize_module(module, prefix=""):
-        nonlocal quant_count
-        for name, child in module.named_children():
-            full_name = f"{prefix}.{name}" if prefix else name
-
-            if isinstance(child, (nn.Linear, nn.Conv2d)):
-                try:
-                    W = child.weight.data.float()
-                    qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
-
-                    del child._parameters["weight"]
-                    child.register_buffer("int8_weight", qW)
-                    child.register_buffer("scale", scale)
-                    if zp is not None:
-                        child.register_buffer("zero_point", zp)
-                    else:
-                        child.zero_point = None
-
-                    child.forward = make_quantized_forward(quant_dtype).__get__(child)
-                    quant_count += 1
-                except Exception as e:
-                    print(f"Failed to quantize {full_name}: {str(e)}")
-
-            _quantize_module(child, full_name)
-
-    _quantize_module(model)
-    print(f"✅ Successfully quantized {quant_count} layers")
-    return model
-
-# ---------------------- ComfyUI Node Implementation ------------------------
-class CheckpointLoaderQuantized:
-    @classmethod
-    def INPUT_TYPES(cls):
-        return {
-            "required": {
-                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
-                "enable_quant": ("BOOLEAN", {"default": True}),
-                "use_asymmetric": ("BOOLEAN", {"default": False}),
-                "quant_dtype": (["float32", "float16"], {"default": "float32"}),  # Toggle for precision
-            }
-        }
-
-    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
-    FUNCTION = "load_quantized"
-    CATEGORY = "Loaders (Quantized)"
-    OUTPUT_NODE = False
-
-    def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype):
-        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
-
-        if not os.path.exists(ckpt_path):
-            raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")
-
-        model_patcher, clip, vae, _ = load_checkpoint_guess_config(
-            ckpt_path,
-            output_vae=True,
-            output_clip=True,
-            embedding_directory=folder_paths.get_folder_paths("embeddings")
-        )
-
-        if enable_quant:
-            mode = "Asymmetric" if use_asymmetric else "Symmetric"
-            print(f"🔧 Applying {mode} 8-bit quantization to {ckpt_name} (dtype={quant_dtype})")
-            apply_quantization(model_patcher.model, use_asymmetric=use_asymmetric, quant_dtype=quant_dtype)
-        else:
-            print(f"🔧 Loading {ckpt_name} without quantization")
-
-        return (model_patcher, clip, vae)
-
-# ------------------------- Node Registration -------------------------------
-NODE_CLASS_MAPPINGS = {
-    "CheckpointLoaderQuantized": CheckpointLoaderQuantized,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    "CheckpointLoaderQuantized": "CFZ Checkpoint Loader",
-}
-
-__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']