# ComfyUI/cfz/cfz_patcher.py
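"""CFZ quantization patcher for ComfyUI.

Provides post-load 8-bit weight quantization (symmetric int8 or asymmetric uint8,
per output channel) for checkpoints and UNet models, with an optional torch._int_mm
fast path for symmetric int8 linear layers. Registers three nodes: CFZ Checkpoint
Loader (Optimized), CFZ Model Quantization Patcher, and CFZ UNet Quantization Patcher.
"""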
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.sd import load_checkpoint_guess_config, load_checkpoint
from comfy.model_patcher import ModelPatcher
import folder_paths
# ------------------------ Optimized Quantization Logic -------------------------
def quantize_input_for_int8_matmul(input_tensor, weight_scale):
"""Quantize input tensor for optimized int8 matrix multiplication"""
# Calculate input scale per batch/sequence dimension
input_scale = input_tensor.abs().amax(dim=-1, keepdim=True) / 127.0
input_scale = torch.clamp(input_scale, min=1e-8)
# Quantize input to int8
quantized_input = torch.clamp(
(input_tensor / input_scale).round(), -128, 127
).to(torch.int8)
# Combine input and weight scales
combined_scale = input_scale * weight_scale
# Flatten tensors for matrix multiplication if needed
original_shape = input_tensor.shape
if input_tensor.dim() > 2:
quantized_input = quantized_input.flatten(0, -2).contiguous()
combined_scale = combined_scale.flatten(0, -2).contiguous()
# Ensure scale precision for accurate computation
if combined_scale.dtype == torch.float16:
combined_scale = combined_scale.to(torch.float32)
return quantized_input, combined_scale, original_shape
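# Shape sketch for quantize_input_for_int8_matmul (hypothetical sizes, in=768, out=320,
# with weight_scale already squeezed to [out] as done for the optimized linear layers):
#   input_tensor [2, 77, 768] -> quantized_input [154, 768] int8   (leading dims flattened)
#   weight_scale [320]        -> combined_scale  [154, 320] fp32   (input scale * weight scale)
#   original_shape == (2, 77, 768), used by the caller to un-flatten the matmul result.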
def optimized_int8_matmul(input_tensor, quantized_weight, weight_scale, bias=None):
"""Optimized int8 matrix multiplication using torch._int_mm"""
batch_size = input_tensor.numel() // input_tensor.shape[-1]
# Performance threshold: only use optimized path for larger matrices
# This prevents overhead from dominating small computations
if batch_size >= 32 and input_tensor.shape[-1] >= 32:
# Quantize input tensor for int8 computation
q_input, combined_scale, orig_shape = quantize_input_for_int8_matmul(
input_tensor, weight_scale
)
        # Perform the int8 x int8 -> int32 matrix multiplication; on backends with
        # int8 support this is typically faster than the equivalent fp16/fp32 matmul.
        result = torch._int_mm(q_input, quantized_weight)
# Dequantize result back to floating point
result = result.to(combined_scale.dtype) * combined_scale
# Reshape result back to original input dimensions
if len(orig_shape) > 2:
new_shape = list(orig_shape[:-1]) + [quantized_weight.shape[-1]]
result = result.reshape(new_shape)
# Add bias if present
if bias is not None:
result = result + bias
return result
    else:
        # Fall back to plain dequantization for small matrices, where quantization
        # overhead is not worth it. The weight arriving here is already in the
        # transposed [in, out] layout, so multiply directly rather than calling
        # F.linear (which expects an [out, in] weight).
        dequantized_weight = quantized_weight.to(input_tensor.dtype) * weight_scale
        result = input_tensor @ dequantized_weight
        return result if bias is None else result + bias
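# Usage sketch for optimized_int8_matmul (hypothetical tensors; real callers pass the
# transposed int8 weight and squeezed scale prepared by apply_optimized_quantization):
#   x  = torch.randn(2, 77, 768)                                   # activations
#   qW = torch.randint(-127, 128, (768, 320), dtype=torch.int8)    # weight in [in, out] layout
#   s  = torch.rand(320) / 127.0                                   # per-output-channel scale
#   y  = optimized_int8_matmul(x, qW, s)                           # -> [2, 77, 320]
# Inputs with fewer than 32 rows or features take the dequantize-and-matmul fallback;
# torch._int_mm itself has additional backend-specific shape and device requirements.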
def make_optimized_quantized_forward(quant_dtype="float32", use_int8_matmul=True):
"""Create an optimized quantized forward function for neural network layers"""
def forward(self, x):
# Determine computation precision
dtype = torch.float32 if quant_dtype == "float32" else torch.float16
# Get input device for consistent placement
device = x.device
# Move quantized weights and scales to input device AND dtype
qW = self.int8_weight.to(device)
scale = self.scale.to(device, dtype=dtype)
# Handle zero point for asymmetric quantization
if hasattr(self, 'zero_point') and self.zero_point is not None:
zp = self.zero_point.to(device, dtype=dtype)
else:
zp = None
# Ensure input is in correct precision
x = x.to(dtype)
# Prepare bias if present - ENSURE IT'S ON THE CORRECT DEVICE
bias = None
if self.bias is not None:
bias = self.bias.to(device, dtype=dtype)
# Apply LoRA adaptation if present (before main computation for better accuracy)
lora_output = None
if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
# Ensure LoRA weights are on correct device
lora_down = self.lora_down.to(device)
lora_up = self.lora_up.to(device)
lora_output = lora_up(lora_down(x)) * self.lora_alpha
# Choose computation path based on layer type and optimization settings
if isinstance(self, nn.Linear):
# Linear layers can use optimized int8 matmul
if (use_int8_matmul and zp is None and
hasattr(self, '_use_optimized_matmul') and self._use_optimized_matmul):
# Use optimized path (only for symmetric quantization)
result = optimized_int8_matmul(x, qW, scale, bias)
else:
# Standard dequantization path
if zp is not None:
# Asymmetric quantization: subtract zero point then scale
W = (qW.to(dtype) - zp) * scale
else:
# Symmetric quantization: just scale
W = qW.to(dtype) * scale
result = F.linear(x, W, bias)
elif isinstance(self, nn.Conv2d):
# Convolution layers use standard dequantization
if zp is not None:
W = (qW.to(dtype) - zp) * scale
else:
W = qW.to(dtype) * scale
result = F.conv2d(x, W, bias, self.stride, self.padding, self.dilation, self.groups)
else:
# Fallback for unsupported layer types
return x
# Add LoRA output if computed
if lora_output is not None:
result = result + lora_output
return result
return forward
def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
"""Quantize weights with support for both symmetric and asymmetric quantization"""
# Determine reduction dimensions (preserve output channels)
reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]
if use_asymmetric:
# Asymmetric quantization: use full range [0, 255] for uint8
min_val = weight.amin(dim=reduce_dim, keepdim=True)
max_val = weight.amax(dim=reduce_dim, keepdim=True)
scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
else:
# Symmetric quantization: use range [-127, 127] for int8
w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
scale = torch.clamp(w_max / 127.0, min=1e-8)
qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
zero_point = None
return qweight, scale.to(torch.float16), zero_point
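# Round-trip sanity check for quantize_weight (symmetric path, pure torch, no ComfyUI needed):
#   W = torch.randn(320, 768)
#   qW, scale, zp = quantize_weight(W)        # int8 weight, fp16 per-output-channel scale, zp is None
#   W_hat = qW.float() * scale.float()        # dequantize
#   (W - W_hat).abs().max()                   # roughly bounded by 0.5 * scale.float().max()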
def apply_optimized_quantization(model, use_asymmetric=False, quant_dtype="float32",
use_int8_matmul=True):
"""Apply quantization with optimized inference paths to a neural network model"""
quant_count = 0
def _quantize_module(module, prefix=""):
nonlocal quant_count
for name, child in module.named_children():
full_name = f"{prefix}.{name}" if prefix else name
# Skip text encoder and CLIP-related modules to avoid conditioning issues
if any(skip_name in full_name.lower() for skip_name in
['text_encoder', 'clip', 'embedder', 'conditioner']):
print(f"⏭️ Skipping {full_name} (text/conditioning module)")
_quantize_module(child, full_name)
continue
if isinstance(child, (nn.Linear, nn.Conv2d)):
try:
# Extract and quantize weights
W = child.weight.data.float()
qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
# Store original device info before removing weight
original_device = child.weight.device
# Remove original weight parameter to save memory
del child._parameters["weight"]
# Register quantized parameters as buffers (non-trainable)
                    # kept on the same device as the original weight tensor
child.register_buffer("int8_weight", qW.to(original_device))
child.register_buffer("scale", scale.to(original_device))
if zp is not None:
child.register_buffer("zero_point", zp.to(original_device))
else:
child.zero_point = None
# Configure optimization settings for this layer
if isinstance(child, nn.Linear) and not use_asymmetric and use_int8_matmul:
# Enable optimized matmul for symmetric quantized linear layers
child._use_optimized_matmul = True
# Transpose weight for optimized matmul layout
child.int8_weight = child.int8_weight.transpose(0, 1).contiguous()
# Adjust scale dimensions for matmul
child.scale = child.scale.squeeze(-1)
else:
child._use_optimized_matmul = False
# Assign optimized forward function
child.forward = make_optimized_quantized_forward(
quant_dtype, use_int8_matmul
).__get__(child)
quant_count += 1
opt_status = "optimized" if child._use_optimized_matmul else "standard"
# print(f"✅ Quantized {full_name} ({opt_status})")
except Exception as e:
print(f"❌ Failed to quantize {full_name}: {str(e)}")
# Recursively process child modules
_quantize_module(child, full_name)
_quantize_module(model)
print(f"✅ Successfully quantized {quant_count} layers with optimized inference")
return model
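# Example (sketch): quantizing a toy module in isolation. The name-based skip list above
# only matters for real diffusion models; nn.Sequential child names ("0", "1", ...) never
# match it.
#   toy = nn.Sequential(nn.Linear(768, 320), nn.GELU(), nn.Linear(320, 768))
#   apply_optimized_quantization(toy)
#   y = toy(torch.randn(4, 77, 768))   # large batch -> optimized int8 path, where supported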
# ---------------------- ComfyUI Node Implementations ------------------------
class CheckpointLoaderQuantized2:
"""Original checkpoint loader with quantization"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
"enable_quant": ("BOOLEAN", {"default": True}),
"use_asymmetric": ("BOOLEAN", {"default": False}),
"quant_dtype": (["float32", "float16"], {"default": "float32"}),
"use_int8_matmul": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_quantized"
CATEGORY = "Loaders (Quantized)"
OUTPUT_NODE = False
def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype,
use_int8_matmul):
"""Load and optionally quantize a checkpoint with optimized inference"""
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")
# Load checkpoint using ComfyUI's standard loading mechanism
model_patcher, clip, vae, _ = load_checkpoint_guess_config(
ckpt_path,
output_vae=True,
output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings")
)
if enable_quant:
# Determine quantization configuration
quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"
print(f"🔧 Applying {quant_mode} 8-bit quantization to {ckpt_name}")
print(f" MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")
# Apply quantization with optimizations
apply_optimized_quantization(
model_patcher.model,
use_asymmetric=use_asymmetric,
quant_dtype=quant_dtype,
use_int8_matmul=use_int8_matmul
)
else:
print(f"🔧 Loading {ckpt_name} without quantization")
return (model_patcher, clip, vae)
class ModelQuantizationPatcher:
"""Quantization patcher that can be applied to any model in the workflow"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"use_asymmetric": ("BOOLEAN", {"default": False}),
"quant_dtype": (["float32", "float16"], {"default": "float32"}),
"use_int8_matmul": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch_model"
CATEGORY = "Model Patching"
OUTPUT_NODE = False
def patch_model(self, model, use_asymmetric, quant_dtype, use_int8_matmul):
"""Apply quantization to an existing model"""
# Clone the model to avoid modifying the original
import copy
quantized_model = copy.deepcopy(model)
# Determine quantization configuration
quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"
print(f"🔧 Applying {quant_mode} 8-bit quantization to model")
print(f" MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")
# Apply quantization with optimizations
apply_optimized_quantization(
quantized_model.model,
use_asymmetric=use_asymmetric,
quant_dtype=quant_dtype,
use_int8_matmul=use_int8_matmul
)
return (quantized_model,)
class UNetQuantizationPatcher:
"""Specialized quantization patcher for UNet models loaded separately"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"use_asymmetric": ("BOOLEAN", {"default": False}),
"quant_dtype": (["float32", "float16"], {"default": "float32"}),
"use_int8_matmul": ("BOOLEAN", {"default": True}),
"skip_input_blocks": ("BOOLEAN", {"default": False}),
"skip_output_blocks": ("BOOLEAN", {"default": False}),
"show_memory_usage": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch_unet"
CATEGORY = "Model Patching"
OUTPUT_NODE = False
def get_model_memory_usage(self, model, force_calculation=False):
"""Calculate memory usage of model parameters (CPU + GPU)"""
total_memory = 0
param_count = 0
gpu_memory = 0
# Count all parameters (CPU + GPU)
for param in model.parameters():
memory_bytes = param.data.element_size() * param.data.nelement()
total_memory += memory_bytes
param_count += param.data.nelement()
if param.data.is_cuda:
gpu_memory += memory_bytes
# Also check for quantized buffers
for name, buffer in model.named_buffers():
if 'int8_weight' in name or 'scale' in name or 'zero_point' in name:
memory_bytes = buffer.element_size() * buffer.nelement()
total_memory += memory_bytes
if buffer.is_cuda:
gpu_memory += memory_bytes
# If force_calculation is True and nothing on GPU, return total memory as estimate
if force_calculation and gpu_memory == 0:
return total_memory, param_count, total_memory
return total_memory, param_count, gpu_memory
def format_memory_size(self, bytes_size):
"""Format memory size in human readable format"""
for unit in ['B', 'KB', 'MB', 'GB']:
if bytes_size < 1024.0:
return f"{bytes_size:.2f} {unit}"
bytes_size /= 1024.0
return f"{bytes_size:.2f} TB"
def patch_unet(self, model, use_asymmetric, quant_dtype, use_int8_matmul,
skip_input_blocks, skip_output_blocks, show_memory_usage):
"""Apply selective quantization to UNet model with block-level control"""
import copy
# Measure original memory usage
if show_memory_usage:
original_memory, original_params, original_gpu = self.get_model_memory_usage(model.model, force_calculation=True)
print(f"📊 Original Model Memory Usage:")
print(f" Parameters: {original_params:,}")
print(f" Total Size: {self.format_memory_size(original_memory)}")
if original_gpu > 0:
print(f" GPU Memory: {self.format_memory_size(original_gpu)}")
else:
print(f" GPU Memory: Not loaded (will use ~{self.format_memory_size(original_memory)} when loaded)")
quantized_model = copy.deepcopy(model)
# Determine quantization configuration
quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"
print(f"🔧 Applying {quant_mode} 8-bit quantization to UNet")
print(f" MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")
if skip_input_blocks or skip_output_blocks:
print(f" Skipping: Input blocks={skip_input_blocks}, Output blocks={skip_output_blocks}")
# Apply quantization with selective skipping
self._apply_selective_quantization(
quantized_model.model,
use_asymmetric=use_asymmetric,
quant_dtype=quant_dtype,
use_int8_matmul=use_int8_matmul,
skip_input_blocks=skip_input_blocks,
skip_output_blocks=skip_output_blocks
)
# Measure quantized memory usage
if show_memory_usage:
quantized_memory, quantized_params, quantized_gpu = self.get_model_memory_usage(quantized_model.model, force_calculation=True)
memory_saved = original_memory - quantized_memory
memory_reduction_pct = (memory_saved / original_memory) * 100 if original_memory > 0 else 0
print(f"📊 Quantized Model Memory Usage:")
print(f" Parameters: {quantized_params:,}")
print(f" Total Size: {self.format_memory_size(quantized_memory)}")
if quantized_gpu > 0:
print(f" GPU Memory: {self.format_memory_size(quantized_gpu)}")
else:
print(f" GPU Memory: Not loaded (will use ~{self.format_memory_size(quantized_memory)} when loaded)")
print(f" Memory Saved: {self.format_memory_size(memory_saved)} ({memory_reduction_pct:.1f}%)")
# Show CUDA memory info if available
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated()
reserved = torch.cuda.memory_reserved()
print(f"📊 Total GPU Memory Status:")
print(f" Currently Allocated: {self.format_memory_size(allocated)}")
print(f" Reserved by PyTorch: {self.format_memory_size(reserved)}")
return (quantized_model,)
def _apply_selective_quantization(self, model, use_asymmetric=False, quant_dtype="float32",
use_int8_matmul=True, skip_input_blocks=False,
skip_output_blocks=False):
"""Apply quantization with selective block skipping for UNet"""
quant_count = 0
def _quantize_module(module, prefix=""):
nonlocal quant_count
for name, child in module.named_children():
full_name = f"{prefix}.{name}" if prefix else name
# Skip blocks based on user preference
if skip_input_blocks and "input_blocks" in full_name:
print(f"⏭️ Skipping {full_name} (input block)")
_quantize_module(child, full_name)
continue
if skip_output_blocks and "output_blocks" in full_name:
print(f"⏭️ Skipping {full_name} (output block)")
_quantize_module(child, full_name)
continue
# Skip text encoder and CLIP-related modules
if any(skip_name in full_name.lower() for skip_name in
['text_encoder', 'clip', 'embedder', 'conditioner']):
print(f" Skipping {full_name} (text/conditioning module)")
_quantize_module(child, full_name)
continue
if isinstance(child, (nn.Linear, nn.Conv2d)):
try:
# Extract and quantize weights
W = child.weight.data.float()
qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
# Store original device info before removing weight
original_device = child.weight.device
# Remove original weight parameter to save memory
del child._parameters["weight"]
# Register quantized parameters as buffers (non-trainable)
child.register_buffer("int8_weight", qW.to(original_device))
child.register_buffer("scale", scale.to(original_device))
if zp is not None:
child.register_buffer("zero_point", zp.to(original_device))
else:
child.zero_point = None
# Configure optimization settings for this layer
if isinstance(child, nn.Linear) and not use_asymmetric and use_int8_matmul:
# Enable optimized matmul for symmetric quantized linear layers
child._use_optimized_matmul = True
# Transpose weight for optimized matmul layout
child.int8_weight = child.int8_weight.transpose(0, 1).contiguous()
# Adjust scale dimensions for matmul
child.scale = child.scale.squeeze(-1)
else:
child._use_optimized_matmul = False
# Assign optimized forward function
child.forward = make_optimized_quantized_forward(
quant_dtype, use_int8_matmul
).__get__(child)
quant_count += 1
except Exception as e:
print(f"❌ Failed to quantize {full_name}: {str(e)}")
# Recursively process child modules
_quantize_module(child, full_name)
_quantize_module(model)
print(f"✅ Successfully quantized {quant_count} layers with selective patching")
# ------------------------- Node Registration -------------------------------
NODE_CLASS_MAPPINGS = {
"CheckpointLoaderQuantized2": CheckpointLoaderQuantized2,
"ModelQuantizationPatcher": ModelQuantizationPatcher,
"UNetQuantizationPatcher": UNetQuantizationPatcher,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointLoaderQuantized2": "CFZ Checkpoint Loader (Optimized)",
"ModelQuantizationPatcher": "CFZ Model Quantization Patcher",
"UNetQuantizationPatcher": "CFZ UNet Quantization Patcher",
}
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
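
# Minimal smoke test (sketch). Because of the module-level comfy/folder_paths imports this
# block only runs when the file is executed directly inside a ComfyUI Python environment;
# the test itself exercises only the pure-torch code paths above, and the layer sizes are
# arbitrary.
if __name__ == "__main__":
    import copy

    torch.manual_seed(0)
    reference = nn.Linear(64, 128)
    x = torch.randn(4, 64)
    y_ref = reference(x)

    # Wrap in a container so apply_optimized_quantization can walk named_children().
    quantized = nn.Sequential(copy.deepcopy(reference))
    apply_optimized_quantization(quantized)

    # A batch of 4 stays on the dequantize fallback, so this runs on any backend.
    y_q = quantized(x)
    print("Max abs error after int8 quantization:",
          (y_ref - y_q).abs().max().item())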