Mirror of https://github.com/comfyanonymous/ComfyUI.git
Commit 74545c2acd (parent f599706b6b): Add files via upload
cfz/nodes/cfz_cudnn.toggle.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import torch


class AnyType(str):
    """A special class that is always equal in not equal comparisons. Credit to pythongosssss"""

    def __ne__(self, __value: object) -> bool:
        return False


anyType = AnyType("*")


class CUDNNToggleAutoPassthrough:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "optional": {
                "model": ("MODEL",),
                "conditioning": ("CONDITIONING",),
                "latent": ("LATENT",),
                "audio": ("AUDIO",),
                "image": ("IMAGE",),
                "wan_model": ("WANVIDEOMODEL",),
                "any_input": (anyType, {}),
            },
            "required": {
                "enable_cudnn": ("BOOLEAN", {"default": True}),
                "cudnn_benchmark": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "LATENT", "AUDIO", "IMAGE", "WANVIDEOMODEL", anyType, "BOOLEAN", "BOOLEAN")
    RETURN_NAMES = ("model", "conditioning", "latent", "audio", "image", "wan_model", "any_output", "prev_cudnn", "prev_benchmark")
    FUNCTION = "toggle"
    CATEGORY = "utils"

    def toggle(self, enable_cudnn, cudnn_benchmark, any_input=None, wan_model=None, model=None, conditioning=None, latent=None, audio=None, image=None):
        prev_cudnn = torch.backends.cudnn.enabled
        prev_benchmark = torch.backends.cudnn.benchmark
        torch.backends.cudnn.enabled = enable_cudnn
        torch.backends.cudnn.benchmark = cudnn_benchmark
        if enable_cudnn != prev_cudnn:
            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.enabled set to {enable_cudnn} (was {prev_cudnn})")
        else:
            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.enabled still set to {enable_cudnn}")

        if cudnn_benchmark != prev_benchmark:
            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.benchmark set to {cudnn_benchmark} (was {prev_benchmark})")
        else:
            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.benchmark still set to {cudnn_benchmark}")

        return_tuple = (model, conditioning, latent, audio, image, wan_model, any_input, prev_cudnn, prev_benchmark)
        return return_tuple


NODE_CLASS_MAPPINGS = {
    "CUDNNToggleAutoPassthrough": CUDNNToggleAutoPassthrough
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CUDNNToggleAutoPassthrough": "CFZ CUDNN Toggle"
}
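For context, the AnyType("*") wildcard above relies on ComfyUI comparing socket types with a not-equal check: because __ne__ always returns False, that check never rejects a connection, so the any_input/any_output socket accepts any link type. A minimal standalone illustration (the class is repeated here so the snippet runs on its own; it is not additional code from the node pack):

    class AnyType(str):
        """Never compares as 'not equal', so ComfyUI's type check never rejects it."""
        def __ne__(self, __value: object) -> bool:
            return False

    wildcard = AnyType("*")
    print(wildcard != "IMAGE")   # False: the != comparison used for link validation never fails
    print(wildcard != "LATENT")  # False
    print(wildcard == "IMAGE")   # False: ordinary string equality is unchanged

Note also that torch.backends.cudnn.enabled and torch.backends.cudnn.benchmark are process-wide flags, so the toggle affects every node that runs after this one in the workflow, not just the values passed through its sockets; the prev_cudnn/prev_benchmark outputs exist so the previous state can be restored downstream.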
cfz/nodes/cfz_patcher.py (new file, 543 lines)
@@ -0,0 +1,543 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.sd import load_checkpoint_guess_config, load_checkpoint
from comfy.model_patcher import ModelPatcher
import folder_paths


# ------------------------ Optimized Quantization Logic -------------------------
def quantize_input_for_int8_matmul(input_tensor, weight_scale):
    """Quantize input tensor for optimized int8 matrix multiplication"""
    # Calculate input scale per batch/sequence dimension
    input_scale = input_tensor.abs().amax(dim=-1, keepdim=True) / 127.0
    input_scale = torch.clamp(input_scale, min=1e-8)

    # Quantize input to int8
    quantized_input = torch.clamp(
        (input_tensor / input_scale).round(), -128, 127
    ).to(torch.int8)

    # Combine input and weight scales
    combined_scale = input_scale * weight_scale

    # Flatten tensors for matrix multiplication if needed
    original_shape = input_tensor.shape
    if input_tensor.dim() > 2:
        quantized_input = quantized_input.flatten(0, -2).contiguous()
        combined_scale = combined_scale.flatten(0, -2).contiguous()

    # Ensure scale precision for accurate computation
    if combined_scale.dtype == torch.float16:
        combined_scale = combined_scale.to(torch.float32)

    return quantized_input, combined_scale, original_shape


def optimized_int8_matmul(input_tensor, quantized_weight, weight_scale, bias=None):
    """Optimized int8 matrix multiplication using torch._int_mm"""
    batch_size = input_tensor.numel() // input_tensor.shape[-1]

    # Performance threshold: only use the optimized path for larger matrices.
    # This prevents quantization overhead from dominating small computations.
    if batch_size >= 32 and input_tensor.shape[-1] >= 32:
        # Quantize input tensor for int8 computation
        q_input, combined_scale, orig_shape = quantize_input_for_int8_matmul(
            input_tensor, weight_scale
        )

        # Perform optimized int8 matrix multiplication
        # This is significantly faster than standard floating-point operations
        result = torch._int_mm(q_input, quantized_weight)

        # Dequantize result back to floating point
        result = result.to(combined_scale.dtype) * combined_scale

        # Reshape result back to original input dimensions
        if len(orig_shape) > 2:
            new_shape = list(orig_shape[:-1]) + [quantized_weight.shape[-1]]
            result = result.reshape(new_shape)

        # Add bias if present
        if bias is not None:
            result = result + bias

        return result
    else:
        # Fall back to standard dequantization for small matrices.
        # This avoids quantization overhead when it's not beneficial.
        # int8_weight is stored transposed as (in_features, out_features) for the
        # matmul path, so transpose back to (out_features, in_features) for F.linear.
        dequantized_weight = (quantized_weight.to(input_tensor.dtype) * weight_scale).t()
        return F.linear(input_tensor, dequantized_weight, bias)
def make_optimized_quantized_forward(quant_dtype="float32", use_int8_matmul=True):
    """Create an optimized quantized forward function for neural network layers"""
    def forward(self, x):
        # Determine computation precision
        dtype = torch.float32 if quant_dtype == "float32" else torch.float16

        # Get input device for consistent placement
        device = x.device

        # Move quantized weights and scales to input device AND dtype
        qW = self.int8_weight.to(device)
        scale = self.scale.to(device, dtype=dtype)

        # Handle zero point for asymmetric quantization
        if hasattr(self, 'zero_point') and self.zero_point is not None:
            zp = self.zero_point.to(device, dtype=dtype)
        else:
            zp = None

        # Ensure input is in correct precision
        x = x.to(dtype)

        # Prepare bias if present, and make sure it is on the correct device
        bias = None
        if self.bias is not None:
            bias = self.bias.to(device, dtype=dtype)

        # Apply LoRA adaptation if present (before main computation for better accuracy)
        lora_output = None
        if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
            # Ensure LoRA weights are on correct device
            lora_down = self.lora_down.to(device)
            lora_up = self.lora_up.to(device)
            lora_output = lora_up(lora_down(x)) * self.lora_alpha

        # Choose computation path based on layer type and optimization settings
        if isinstance(self, nn.Linear):
            # Linear layers can use optimized int8 matmul
            if (use_int8_matmul and zp is None and
                    hasattr(self, '_use_optimized_matmul') and self._use_optimized_matmul):
                # Use optimized path (only for symmetric quantization)
                result = optimized_int8_matmul(x, qW, scale, bias)
            else:
                # Standard dequantization path
                if zp is not None:
                    # Asymmetric quantization: subtract zero point then scale
                    W = (qW.to(dtype) - zp) * scale
                else:
                    # Symmetric quantization: just scale
                    W = qW.to(dtype) * scale
                result = F.linear(x, W, bias)

        elif isinstance(self, nn.Conv2d):
            # Convolution layers use standard dequantization
            if zp is not None:
                W = (qW.to(dtype) - zp) * scale
            else:
                W = qW.to(dtype) * scale
            result = F.conv2d(x, W, bias, self.stride, self.padding, self.dilation, self.groups)

        else:
            # Fallback for unsupported layer types
            return x

        # Add LoRA output if computed
        if lora_output is not None:
            result = result + lora_output

        return result

    return forward
def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
    """Quantize weights with support for both symmetric and asymmetric quantization"""
    # Determine reduction dimensions (preserve output channels)
    reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]

    if use_asymmetric:
        # Asymmetric quantization: use full range [0, 255] for uint8
        min_val = weight.amin(dim=reduce_dim, keepdim=True)
        max_val = weight.amax(dim=reduce_dim, keepdim=True)
        scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
        zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
        qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
    else:
        # Symmetric quantization: use range [-128, 127] for int8
        w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
        scale = torch.clamp(w_max / 127.0, min=1e-8)
        qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
        zero_point = None

    return qweight, scale.to(torch.float16), zero_point
def apply_optimized_quantization(model, use_asymmetric=False, quant_dtype="float32",
                                 use_int8_matmul=True):
    """Apply quantization with optimized inference paths to a neural network model"""
    quant_count = 0

    def _quantize_module(module, prefix=""):
        nonlocal quant_count
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name

            # Skip text encoder and CLIP-related modules to avoid conditioning issues
            if any(skip_name in full_name.lower() for skip_name in
                   ['text_encoder', 'clip', 'embedder', 'conditioner']):
                print(f"⏭️ Skipping {full_name} (text/conditioning module)")
                _quantize_module(child, full_name)
                continue

            if isinstance(child, (nn.Linear, nn.Conv2d)):
                try:
                    # Extract and quantize weights
                    W = child.weight.data.float()
                    qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)

                    # Store original device info before removing weight
                    original_device = child.weight.device

                    # Remove original weight parameter to save memory
                    del child._parameters["weight"]

                    # Register quantized parameters as buffers (non-trainable)
                    # Keep them on CPU initially to save GPU memory
                    child.register_buffer("int8_weight", qW.to(original_device))
                    child.register_buffer("scale", scale.to(original_device))
                    if zp is not None:
                        child.register_buffer("zero_point", zp.to(original_device))
                    else:
                        child.zero_point = None

                    # Configure optimization settings for this layer
                    if isinstance(child, nn.Linear) and not use_asymmetric and use_int8_matmul:
                        # Enable optimized matmul for symmetric quantized linear layers
                        child._use_optimized_matmul = True
                        # Transpose weight for optimized matmul layout
                        child.int8_weight = child.int8_weight.transpose(0, 1).contiguous()
                        # Adjust scale dimensions for matmul
                        child.scale = child.scale.squeeze(-1)
                    else:
                        child._use_optimized_matmul = False

                    # Assign optimized forward function
                    child.forward = make_optimized_quantized_forward(
                        quant_dtype, use_int8_matmul
                    ).__get__(child)

                    quant_count += 1
                    opt_status = "optimized" if child._use_optimized_matmul else "standard"
                    # print(f"✅ Quantized {full_name} ({opt_status})")

                except Exception as e:
                    print(f"❌ Failed to quantize {full_name}: {str(e)}")

            # Recursively process child modules
            _quantize_module(child, full_name)

    _quantize_module(model)
    print(f"✅ Successfully quantized {quant_count} layers with optimized inference")
    return model
# ---------------------- ComfyUI Node Implementations ------------------------

class CheckpointLoaderQuantized2:
    """Original checkpoint loader with quantization"""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                "enable_quant": ("BOOLEAN", {"default": True}),
                "use_asymmetric": ("BOOLEAN", {"default": False}),
                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
                "use_int8_matmul": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    FUNCTION = "load_quantized"
    CATEGORY = "Loaders (Quantized)"
    OUTPUT_NODE = False

    def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype,
                       use_int8_matmul):
        """Load and optionally quantize a checkpoint with optimized inference"""
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)

        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")

        # Load checkpoint using ComfyUI's standard loading mechanism
        model_patcher, clip, vae, _ = load_checkpoint_guess_config(
            ckpt_path,
            output_vae=True,
            output_clip=True,
            embedding_directory=folder_paths.get_folder_paths("embeddings")
        )

        if enable_quant:
            # Determine quantization configuration
            quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
            matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"

            print(f"🔧 Applying {quant_mode} 8-bit quantization to {ckpt_name}")
            print(f"   MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")

            # Apply quantization with optimizations
            apply_optimized_quantization(
                model_patcher.model,
                use_asymmetric=use_asymmetric,
                quant_dtype=quant_dtype,
                use_int8_matmul=use_int8_matmul
            )
        else:
            print(f"🔧 Loading {ckpt_name} without quantization")

        return (model_patcher, clip, vae)
class ModelQuantizationPatcher:
    """Quantization patcher that can be applied to any model in the workflow"""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "use_asymmetric": ("BOOLEAN", {"default": False}),
                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
                "use_int8_matmul": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch_model"
    CATEGORY = "Model Patching"
    OUTPUT_NODE = False

    def patch_model(self, model, use_asymmetric, quant_dtype, use_int8_matmul):
        """Apply quantization to an existing model"""
        # Clone the model to avoid modifying the original
        import copy
        quantized_model = copy.deepcopy(model)

        # Determine quantization configuration
        quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
        matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"

        print(f"🔧 Applying {quant_mode} 8-bit quantization to model")
        print(f"   MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")

        # Apply quantization with optimizations
        apply_optimized_quantization(
            quantized_model.model,
            use_asymmetric=use_asymmetric,
            quant_dtype=quant_dtype,
            use_int8_matmul=use_int8_matmul
        )

        return (quantized_model,)
class UNetQuantizationPatcher:
    """Specialized quantization patcher for UNet models loaded separately"""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "use_asymmetric": ("BOOLEAN", {"default": False}),
                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
                "use_int8_matmul": ("BOOLEAN", {"default": True}),
                "skip_input_blocks": ("BOOLEAN", {"default": False}),
                "skip_output_blocks": ("BOOLEAN", {"default": False}),
                "show_memory_usage": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch_unet"
    CATEGORY = "Model Patching"
    OUTPUT_NODE = False

    def get_model_memory_usage(self, model, force_calculation=False):
        """Calculate memory usage of model parameters (CPU + GPU)"""
        total_memory = 0
        param_count = 0
        gpu_memory = 0

        # Count all parameters (CPU + GPU)
        for param in model.parameters():
            memory_bytes = param.data.element_size() * param.data.nelement()
            total_memory += memory_bytes
            param_count += param.data.nelement()

            if param.data.is_cuda:
                gpu_memory += memory_bytes

        # Also check for quantized buffers
        for name, buffer in model.named_buffers():
            if 'int8_weight' in name or 'scale' in name or 'zero_point' in name:
                memory_bytes = buffer.element_size() * buffer.nelement()
                total_memory += memory_bytes

                if buffer.is_cuda:
                    gpu_memory += memory_bytes

        # If force_calculation is True and nothing is on the GPU, return total memory as the estimate
        if force_calculation and gpu_memory == 0:
            return total_memory, param_count, total_memory

        return total_memory, param_count, gpu_memory

    def format_memory_size(self, bytes_size):
        """Format memory size in human readable format"""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_size < 1024.0:
                return f"{bytes_size:.2f} {unit}"
            bytes_size /= 1024.0
        return f"{bytes_size:.2f} TB"
    def patch_unet(self, model, use_asymmetric, quant_dtype, use_int8_matmul,
                   skip_input_blocks, skip_output_blocks, show_memory_usage):
        """Apply selective quantization to UNet model with block-level control"""
        import copy

        # Measure original memory usage
        if show_memory_usage:
            original_memory, original_params, original_gpu = self.get_model_memory_usage(model.model, force_calculation=True)
            print(f"📊 Original Model Memory Usage:")
            print(f"   Parameters: {original_params:,}")
            print(f"   Total Size: {self.format_memory_size(original_memory)}")
            if original_gpu > 0:
                print(f"   GPU Memory: {self.format_memory_size(original_gpu)}")
            else:
                print(f"   GPU Memory: Not loaded (will use ~{self.format_memory_size(original_memory)} when loaded)")

        quantized_model = copy.deepcopy(model)

        # Determine quantization configuration
        quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
        matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"

        print(f"🔧 Applying {quant_mode} 8-bit quantization to UNet")
        print(f"   MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")

        if skip_input_blocks or skip_output_blocks:
            print(f"   Skipping: Input blocks={skip_input_blocks}, Output blocks={skip_output_blocks}")

        # Apply quantization with selective skipping
        self._apply_selective_quantization(
            quantized_model.model,
            use_asymmetric=use_asymmetric,
            quant_dtype=quant_dtype,
            use_int8_matmul=use_int8_matmul,
            skip_input_blocks=skip_input_blocks,
            skip_output_blocks=skip_output_blocks
        )

        # Measure quantized memory usage
        if show_memory_usage:
            quantized_memory, quantized_params, quantized_gpu = self.get_model_memory_usage(quantized_model.model, force_calculation=True)
            memory_saved = original_memory - quantized_memory
            memory_reduction_pct = (memory_saved / original_memory) * 100 if original_memory > 0 else 0

            print(f"📊 Quantized Model Memory Usage:")
            print(f"   Parameters: {quantized_params:,}")
            print(f"   Total Size: {self.format_memory_size(quantized_memory)}")
            if quantized_gpu > 0:
                print(f"   GPU Memory: {self.format_memory_size(quantized_gpu)}")
            else:
                print(f"   GPU Memory: Not loaded (will use ~{self.format_memory_size(quantized_memory)} when loaded)")
            print(f"   Memory Saved: {self.format_memory_size(memory_saved)} ({memory_reduction_pct:.1f}%)")

            # Show CUDA memory info if available
            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated()
                reserved = torch.cuda.memory_reserved()
                print(f"📊 Total GPU Memory Status:")
                print(f"   Currently Allocated: {self.format_memory_size(allocated)}")
                print(f"   Reserved by PyTorch: {self.format_memory_size(reserved)}")

        return (quantized_model,)
    def _apply_selective_quantization(self, model, use_asymmetric=False, quant_dtype="float32",
                                      use_int8_matmul=True, skip_input_blocks=False,
                                      skip_output_blocks=False):
        """Apply quantization with selective block skipping for UNet"""
        quant_count = 0

        def _quantize_module(module, prefix=""):
            nonlocal quant_count
            for name, child in module.named_children():
                full_name = f"{prefix}.{name}" if prefix else name

                # Skip blocks based on user preference
                if skip_input_blocks and "input_blocks" in full_name:
                    print(f"⏭️ Skipping {full_name} (input block)")
                    _quantize_module(child, full_name)
                    continue

                if skip_output_blocks and "output_blocks" in full_name:
                    print(f"⏭️ Skipping {full_name} (output block)")
                    _quantize_module(child, full_name)
                    continue

                # Skip text encoder and CLIP-related modules
                if any(skip_name in full_name.lower() for skip_name in
                       ['text_encoder', 'clip', 'embedder', 'conditioner']):
                    print(f"⏭️ Skipping {full_name} (text/conditioning module)")
                    _quantize_module(child, full_name)
                    continue

                if isinstance(child, (nn.Linear, nn.Conv2d)):
                    try:
                        # Extract and quantize weights
                        W = child.weight.data.float()
                        qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)

                        # Store original device info before removing weight
                        original_device = child.weight.device

                        # Remove original weight parameter to save memory
                        del child._parameters["weight"]

                        # Register quantized parameters as buffers (non-trainable)
                        child.register_buffer("int8_weight", qW.to(original_device))
                        child.register_buffer("scale", scale.to(original_device))
                        if zp is not None:
                            child.register_buffer("zero_point", zp.to(original_device))
                        else:
                            child.zero_point = None

                        # Configure optimization settings for this layer
                        if isinstance(child, nn.Linear) and not use_asymmetric and use_int8_matmul:
                            # Enable optimized matmul for symmetric quantized linear layers
                            child._use_optimized_matmul = True
                            # Transpose weight for optimized matmul layout
                            child.int8_weight = child.int8_weight.transpose(0, 1).contiguous()
                            # Adjust scale dimensions for matmul
                            child.scale = child.scale.squeeze(-1)
                        else:
                            child._use_optimized_matmul = False

                        # Assign optimized forward function
                        child.forward = make_optimized_quantized_forward(
                            quant_dtype, use_int8_matmul
                        ).__get__(child)

                        quant_count += 1

                    except Exception as e:
                        print(f"❌ Failed to quantize {full_name}: {str(e)}")

                # Recursively process child modules
                _quantize_module(child, full_name)

        _quantize_module(model)
        print(f"✅ Successfully quantized {quant_count} layers with selective patching")
# ------------------------- Node Registration -------------------------------
NODE_CLASS_MAPPINGS = {
    "CheckpointLoaderQuantized2": CheckpointLoaderQuantized2,
    "ModelQuantizationPatcher": ModelQuantizationPatcher,
    "UNetQuantizationPatcher": UNetQuantizationPatcher,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CheckpointLoaderQuantized2": "CFZ Checkpoint Loader (Optimized)",
    "ModelQuantizationPatcher": "CFZ Model Quantization Patcher",
    "UNetQuantizationPatcher": "CFZ UNet Quantization Patcher",
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
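To make the arithmetic in cfz_patcher.py easier to follow, here is a self-contained sketch (not part of the files in this commit) of the same symmetric per-output-channel scheme: weights get an absmax scale per output row, activations an absmax scale per row, and the integer products are dequantized with the combined scale, exactly as quantize_weight and optimized_int8_matmul do. A plain int32 matmul stands in for torch._int_mm so the sketch runs on any backend; the printed error is the accuracy the optimized path trades for memory and speed.

    import torch
    import torch.nn.functional as F

    def quantize_weight_sym(weight):
        # Per-output-channel absmax scale, as in quantize_weight(use_asymmetric=False)
        w_max = weight.abs().amax(dim=1, keepdim=True)
        scale = torch.clamp(w_max / 127.0, min=1e-8)
        qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
        return qweight, scale

    torch.manual_seed(0)
    x = torch.randn(64, 256)        # activations: (batch, in_features)
    W = torch.randn(512, 256)       # Linear weight: (out_features, in_features)

    qW, w_scale = quantize_weight_sym(W)
    ref = F.linear(x, W)            # float reference

    # Mirror optimized_int8_matmul: quantize activations per row, integer matmul, dequantize
    x_scale = torch.clamp(x.abs().amax(dim=-1, keepdim=True) / 127.0, min=1e-8)
    q_x = torch.clamp((x / x_scale).round(), -128, 127).to(torch.int8)
    acc = q_x.to(torch.int32) @ qW.t().to(torch.int32)        # integer accumulation, shape (64, 512)
    approx = acc.to(torch.float32) * (x_scale * w_scale.t())  # dequantize with the combined scale
    print("max abs error:", (ref - approx).abs().max().item())
    print("relative error:", ((ref - approx).norm() / ref.norm()).item())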
cfz/nodes/cfz_vae_loader.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
import folder_paths
from comfy import model_management
from nodes import VAELoader


class CFZVAELoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vae_name": (folder_paths.get_filename_list("vae"), ),
                "precision": (["fp32", "fp16", "bf16"], {"default": "fp32"}),
            }
        }

    RETURN_TYPES = ("VAE",)
    FUNCTION = "load_vae"
    CATEGORY = "loaders"
    TITLE = "CFZ VAE Loader"

    def load_vae(self, vae_name, precision):
        # Map precision to dtype
        dtype_map = {
            "fp32": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16
        }
        target_dtype = dtype_map[precision]

        # Temporarily patch model_management functions
        original_should_use_bf16 = model_management.should_use_bf16
        original_should_use_fp16 = model_management.should_use_fp16

        def custom_should_use_bf16(*args, **kwargs):
            return precision == "bf16"

        def custom_should_use_fp16(*args, **kwargs):
            return precision == "fp16"

        # Apply patches
        model_management.should_use_bf16 = custom_should_use_bf16
        model_management.should_use_fp16 = custom_should_use_fp16

        try:
            # Load the VAE with patched precision functions
            vae_loader = VAELoader()
            vae = vae_loader.load_vae(vae_name)[0]
            print(f"CFZ VAE: Loaded with forced precision {precision}")
            return (vae,)
        finally:
            # Restore original functions
            model_management.should_use_bf16 = original_should_use_bf16
            model_management.should_use_fp16 = original_should_use_fp16


# Node mappings for ComfyUI
NODE_CLASS_MAPPINGS = {
    "CFZVAELoader": CFZVAELoader
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "CFZVAELoader": "CFZ VAE Loader"
}
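The try/finally monkeypatch in CFZVAELoader.load_vae can also be written as a context manager. The following is only a sketch of that equivalent structure, not code shipped in this commit; it patches the same comfy.model_management functions the loader above already touches:

    from contextlib import contextmanager
    from comfy import model_management

    @contextmanager
    def forced_vae_precision(precision):
        # Save the originals, force the answer for the chosen precision, restore on exit
        orig_bf16 = model_management.should_use_bf16
        orig_fp16 = model_management.should_use_fp16
        model_management.should_use_bf16 = lambda *args, **kwargs: precision == "bf16"
        model_management.should_use_fp16 = lambda *args, **kwargs: precision == "fp16"
        try:
            yield
        finally:
            model_management.should_use_bf16 = orig_bf16
            model_management.should_use_fp16 = orig_fp16

    # load_vae could then read:
    #     with forced_vae_precision(precision):
    #         vae = VAELoader().load_vae(vae_name)[0]

Either form restores the global functions even if loading raises, which matters because the patch would otherwise change the precision decision for every other model loaded in the same process.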