From 74545c2acdef6120d9841e5e4c901550bdc46ffc Mon Sep 17 00:00:00 2001
From: patientx
Date: Tue, 2 Sep 2025 17:59:58 +0300
Subject: [PATCH] Add files via upload

---
 cfz/nodes/cfz_cudnn.toggle.py |  59 ++++
 cfz/nodes/cfz_patcher.py      | 543 ++++++++++++++++++++++++++++++++++
 cfz/nodes/cfz_vae_loader.py   |  62 ++++
 3 files changed, 664 insertions(+)
 create mode 100644 cfz/nodes/cfz_cudnn.toggle.py
 create mode 100644 cfz/nodes/cfz_patcher.py
 create mode 100644 cfz/nodes/cfz_vae_loader.py

diff --git a/cfz/nodes/cfz_cudnn.toggle.py b/cfz/nodes/cfz_cudnn.toggle.py
new file mode 100644
index 000000000..1e13d0659
--- /dev/null
+++ b/cfz/nodes/cfz_cudnn.toggle.py
@@ -0,0 +1,59 @@
+import torch
+
+class AnyType(str):
+    """A special class that always compares equal (its __ne__ always returns False). Credit to pythongosssss"""
+
+    def __ne__(self, __value: object) -> bool:
+        return False
+
+anyType = AnyType("*")
+
+class CUDNNToggleAutoPassthrough:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "optional": {
+                "model": ("MODEL",),
+                "conditioning": ("CONDITIONING",),
+                "latent": ("LATENT",),
+                "audio": ("AUDIO",),
+                "image": ("IMAGE",),
+                "wan_model": ("WANVIDEOMODEL",),
+                "any_input": (anyType, {}),
+            },
+            "required": {
+                "enable_cudnn": ("BOOLEAN", {"default": True}),
+                "cudnn_benchmark": ("BOOLEAN", {"default": False}),
+            },
+        }
+
+    RETURN_TYPES = ("MODEL", "CONDITIONING", "LATENT", "AUDIO", "IMAGE", "WANVIDEOMODEL", anyType, "BOOLEAN", "BOOLEAN")
+    RETURN_NAMES = ("model", "conditioning", "latent", "audio", "image", "wan_model", "any_output", "prev_cudnn", "prev_benchmark")
+    FUNCTION = "toggle"
+    CATEGORY = "utils"
+
+    def toggle(self, enable_cudnn, cudnn_benchmark, any_input=None, wan_model=None, model=None, conditioning=None, latent=None, audio=None, image=None):
+        prev_cudnn = torch.backends.cudnn.enabled
+        prev_benchmark = torch.backends.cudnn.benchmark
+        torch.backends.cudnn.enabled = enable_cudnn
+        torch.backends.cudnn.benchmark = cudnn_benchmark
+        if enable_cudnn != prev_cudnn:
+            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.enabled set to {enable_cudnn} (was {prev_cudnn})")
+        else:
+            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.enabled still set to {enable_cudnn}")
+
+        if cudnn_benchmark != prev_benchmark:
+            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.benchmark set to {cudnn_benchmark} (was {prev_benchmark})")
+        else:
+            print(f"[CUDNN_TOGGLE] torch.backends.cudnn.benchmark still set to {cudnn_benchmark}")
+
+        return_tuple = (model, conditioning, latent, audio, image, wan_model, any_input, prev_cudnn, prev_benchmark)
+        return return_tuple
+
+NODE_CLASS_MAPPINGS = {
+    "CUDNNToggleAutoPassthrough": CUDNNToggleAutoPassthrough
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CUDNNToggleAutoPassthrough": "CFZ CUDNN Toggle"
+}
diff --git a/cfz/nodes/cfz_patcher.py b/cfz/nodes/cfz_patcher.py
new file mode 100644
index 000000000..6a615174b
--- /dev/null
+++ b/cfz/nodes/cfz_patcher.py
@@ -0,0 +1,543 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from comfy.sd import load_checkpoint_guess_config, load_checkpoint
+from comfy.model_patcher import ModelPatcher
+import folder_paths
+
+# ------------------------ Optimized Quantization Logic -------------------------
+def quantize_input_for_int8_matmul(input_tensor, weight_scale):
+    """Quantize input tensor for optimized int8 matrix multiplication"""
+    # Calculate input scale per batch/sequence dimension
+    input_scale = input_tensor.abs().amax(dim=-1, keepdim=True) / 127.0
+    input_scale = torch.clamp(input_scale, min=1e-8)
+
+    # Quantize input to int8
+    quantized_input = torch.clamp(
+        (input_tensor / input_scale).round(), -128, 127
+    ).to(torch.int8)
+
+    # Combine input and weight scales
+    combined_scale = input_scale * weight_scale
+
+    # Flatten tensors for matrix multiplication if needed
+    original_shape = input_tensor.shape
+    if input_tensor.dim() > 2:
+        quantized_input = quantized_input.flatten(0, -2).contiguous()
+        combined_scale = combined_scale.flatten(0, -2).contiguous()
+    # Ensure scale precision for accurate computation
+    if combined_scale.dtype == torch.float16:
+        combined_scale = combined_scale.to(torch.float32)
+
+    return quantized_input, combined_scale, original_shape
+
+def optimized_int8_matmul(input_tensor, quantized_weight, weight_scale, bias=None):
+    """Optimized int8 matrix multiplication using torch._int_mm"""
+    batch_size = input_tensor.numel() // input_tensor.shape[-1]
+
+    # Performance threshold: only use optimized path for larger matrices
+    # This prevents overhead from dominating small computations
+    if batch_size >= 32 and input_tensor.shape[-1] >= 32:
+        # Quantize input tensor for int8 computation
+        q_input, combined_scale, orig_shape = quantize_input_for_int8_matmul(
+            input_tensor, weight_scale
+        )
+
+        # Perform optimized int8 matrix multiplication
+        # This is significantly faster than standard floating-point operations
+        result = torch._int_mm(q_input, quantized_weight)
+
+        # Dequantize result back to floating point
+        result = result.to(combined_scale.dtype) * combined_scale
+
+        # Reshape result back to original input dimensions
+        if len(orig_shape) > 2:
+            new_shape = list(orig_shape[:-1]) + [quantized_weight.shape[-1]]
+            result = result.reshape(new_shape)
+
+        # Add bias if present
+        if bias is not None:
+            result = result + bias
+
+        return result
+    else:
+        # Fall back to standard dequantization for small matrices
+        # This avoids quantization overhead when it's not beneficial
+        dequantized_weight = (quantized_weight.to(input_tensor.dtype) * weight_scale).t()  # weight is stored transposed as [in, out]; F.linear needs [out, in]
+        return F.linear(input_tensor, dequantized_weight, bias)
+
+def make_optimized_quantized_forward(quant_dtype="float32", use_int8_matmul=True):
+    """Create an optimized quantized forward function for neural network layers"""
+    def forward(self, x):
+        # Determine computation precision
+        dtype = torch.float32 if quant_dtype == "float32" else torch.float16
+
+        # Get input device for consistent placement
+        device = x.device
+
+        # Move quantized weights and scales to input device AND dtype
+        qW = self.int8_weight.to(device)
+        scale = self.scale.to(device, dtype=dtype)
+
+        # Handle zero point for asymmetric quantization
+        if hasattr(self, 'zero_point') and self.zero_point is not None:
+            zp = self.zero_point.to(device, dtype=dtype)
+        else:
+            zp = None
+
+        # Ensure input is in correct precision
+        x = x.to(dtype)
+
+        # Prepare bias if present - ENSURE IT'S ON THE CORRECT DEVICE
+        bias = None
+        if self.bias is not None:
+            bias = self.bias.to(device, dtype=dtype)
+
+        # Apply LoRA adaptation if present (before main computation for better accuracy)
+        lora_output = None
+        if hasattr(self, "lora_down") and hasattr(self, "lora_up") and hasattr(self, "lora_alpha"):
+            # Ensure LoRA weights are on correct device
+            lora_down = self.lora_down.to(device)
+            lora_up = self.lora_up.to(device)
+            lora_output = lora_up(lora_down(x)) * self.lora_alpha
+
+        # Choose computation path based on layer type and optimization settings
+        if isinstance(self, nn.Linear):
+            # Linear layers can use optimized int8 matmul
+            if (use_int8_matmul and zp is None and
+                hasattr(self, '_use_optimized_matmul') and self._use_optimized_matmul):
+                # Use optimized path (only for symmetric quantization)
+                result = optimized_int8_matmul(x, qW, scale, bias)
+            else:
+                # Standard dequantization path
+                if zp is not None:
+                    # Asymmetric quantization: subtract zero point then scale
+                    W = (qW.to(dtype) - zp) * scale
+                else:
+                    # Symmetric quantization: just scale
+                    W = qW.to(dtype) * scale
+                result = F.linear(x, W, bias)
+
+        elif isinstance(self, nn.Conv2d):
+            # Convolution layers use standard dequantization
+            if zp is not None:
+                W = (qW.to(dtype) - zp) * scale
+            else:
+                W = qW.to(dtype) * scale
+            result = F.conv2d(x, W, bias, self.stride, self.padding, self.dilation, self.groups)
+
+        else:
+            # Fallback for unsupported layer types
+            return x
+
+        # Add LoRA output if computed
+        if lora_output is not None:
+            result = result + lora_output
+
+        return result
+
+    return forward
+
+def quantize_weight(weight: torch.Tensor, num_bits=8, use_asymmetric=False):
+    """Quantize weights with support for both symmetric and asymmetric quantization"""
+    # Determine reduction dimensions (preserve output channels)
+    reduce_dim = 1 if weight.ndim == 2 else [i for i in range(weight.ndim) if i != 0]
+
+    if use_asymmetric:
+        # Asymmetric quantization: use full range [0, 255] for uint8
+        min_val = weight.amin(dim=reduce_dim, keepdim=True)
+        max_val = weight.amax(dim=reduce_dim, keepdim=True)
+        scale = torch.clamp((max_val - min_val) / 255.0, min=1e-8)
+        zero_point = torch.clamp((-min_val / scale).round(), 0, 255).to(torch.uint8)
+        qweight = torch.clamp((weight / scale + zero_point).round(), 0, 255).to(torch.uint8)
+    else:
+        # Symmetric quantization: scale by max/127, clamp to the int8 range [-128, 127]
+        w_max = weight.abs().amax(dim=reduce_dim, keepdim=True)
+        scale = torch.clamp(w_max / 127.0, min=1e-8)
+        qweight = torch.clamp((weight / scale).round(), -128, 127).to(torch.int8)
+        zero_point = None
+
+    return qweight, scale.to(torch.float16), zero_point
+
+def apply_optimized_quantization(model, use_asymmetric=False, quant_dtype="float32",
+                                 use_int8_matmul=True):
+    """Apply quantization with optimized inference paths to a neural network model"""
+    quant_count = 0
+
+    def _quantize_module(module, prefix=""):
+        nonlocal quant_count
+        for name, child in module.named_children():
+            full_name = f"{prefix}.{name}" if prefix else name
+
+            # Skip text encoder and CLIP-related modules to avoid conditioning issues
+            if any(skip_name in full_name.lower() for skip_name in
+                   ['text_encoder', 'clip', 'embedder', 'conditioner']):
+                print(f"⏭️ Skipping {full_name} (text/conditioning module)")
+                _quantize_module(child, full_name)
+                continue
+
+
+            if isinstance(child, (nn.Linear, nn.Conv2d)):
+                try:
+                    # Extract and quantize weights
+                    W = child.weight.data.float()
+                    qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
+
+                    # Store original device info before removing weight
+                    original_device = child.weight.device
+
+                    # Remove original weight parameter to save memory
+                    del child._parameters["weight"]
+
+                    # Register quantized parameters as buffers (non-trainable)
+                    # Keep them on CPU initially to save GPU memory
+                    child.register_buffer("int8_weight", qW.to(original_device))
+                    child.register_buffer("scale", scale.to(original_device))
+                    if zp is not None:
+                        child.register_buffer("zero_point", zp.to(original_device))
+                    else:
+                        child.zero_point = None
+
+                    # Configure optimization settings for this layer
+                    if isinstance(child, nn.Linear) and not use_asymmetric and use_int8_matmul:
+                        # Enable optimized matmul for symmetric quantized linear layers
+                        child._use_optimized_matmul = True
+                        # Transpose weight for optimized matmul layout
+                        child.int8_weight = child.int8_weight.transpose(0, 1).contiguous()
+                        # Adjust scale dimensions for matmul
+                        child.scale = child.scale.squeeze(-1)
+                    else:
+                        child._use_optimized_matmul = False
+
+                    # Assign optimized forward function
+                    child.forward = make_optimized_quantized_forward(
+                        quant_dtype, use_int8_matmul
+                    ).__get__(child)
+
+                    quant_count += 1
+                    opt_status = "optimized" if child._use_optimized_matmul else "standard"
+                    # print(f"✅ Quantized {full_name} ({opt_status})")
+
+                except Exception as e:
+                    print(f"❌ Failed to quantize {full_name}: {str(e)}")
+
+            # Recursively process child modules
+            _quantize_module(child, full_name)
+
+    _quantize_module(model)
+    print(f"✅ Successfully quantized {quant_count} layers with optimized inference")
+    return model
+
+# ---------------------- ComfyUI Node Implementations ------------------------
+
+class CheckpointLoaderQuantized2:
+    """Original checkpoint loader with quantization"""
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
+                "enable_quant": ("BOOLEAN", {"default": True}),
+                "use_asymmetric": ("BOOLEAN", {"default": False}),
+                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
+                "use_int8_matmul": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
+    FUNCTION = "load_quantized"
+    CATEGORY = "Loaders (Quantized)"
+    OUTPUT_NODE = False
+
+    def load_quantized(self, ckpt_name, enable_quant, use_asymmetric, quant_dtype,
+                       use_int8_matmul):
+        """Load and optionally quantize a checkpoint with optimized inference"""
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+
+        if not os.path.exists(ckpt_path):
+            raise FileNotFoundError(f"Checkpoint {ckpt_name} not found at {ckpt_path}")
+
+        # Load checkpoint using ComfyUI's standard loading mechanism
+        model_patcher, clip, vae, _ = load_checkpoint_guess_config(
+            ckpt_path,
+            output_vae=True,
+            output_clip=True,
+            embedding_directory=folder_paths.get_folder_paths("embeddings")
+        )
+
+        if enable_quant:
+            # Determine quantization configuration
+            quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
+            matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"
+
+            print(f"🔧 Applying {quant_mode} 8-bit quantization to {ckpt_name}")
+            print(f" MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")
+
+            # Apply quantization with optimizations
+            apply_optimized_quantization(
+                model_patcher.model,
+                use_asymmetric=use_asymmetric,
+                quant_dtype=quant_dtype,
+                use_int8_matmul=use_int8_matmul
+            )
+        else:
+            print(f"🔧 Loading {ckpt_name} without quantization")
+
+        return (model_patcher, clip, vae)
+
+
+class ModelQuantizationPatcher:
+    """Quantization patcher that can be applied to any model in the workflow"""
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "use_asymmetric": ("BOOLEAN", {"default": False}),
+                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
+                "use_int8_matmul": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch_model"
+    CATEGORY = "Model Patching"
+    OUTPUT_NODE = False
+
+    def patch_model(self, model, use_asymmetric, quant_dtype, use_int8_matmul):
+        """Apply quantization to an existing model"""
+        # Clone the model to avoid modifying the original
+        import copy
+        quantized_model = copy.deepcopy(model)
+
+        # Determine quantization configuration
+        quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
+        matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"
+
+        print(f"🔧 Applying {quant_mode} 8-bit quantization to model")
+        print(f" MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")
+
+        # Apply quantization with optimizations
+        apply_optimized_quantization(
+            quantized_model.model,
+            use_asymmetric=use_asymmetric,
+            quant_dtype=quant_dtype,
+            use_int8_matmul=use_int8_matmul
+        )
+
+        return (quantized_model,)
+
+
+class UNetQuantizationPatcher:
+    """Specialized quantization patcher for UNet models loaded separately"""
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("MODEL",),
+                "use_asymmetric": ("BOOLEAN", {"default": False}),
+                "quant_dtype": (["float32", "float16"], {"default": "float32"}),
+                "use_int8_matmul": ("BOOLEAN", {"default": True}),
+                "skip_input_blocks": ("BOOLEAN", {"default": False}),
+                "skip_output_blocks": ("BOOLEAN", {"default": False}),
+                "show_memory_usage": ("BOOLEAN", {"default": True}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch_unet"
+    CATEGORY = "Model Patching"
+    OUTPUT_NODE = False
+
+    def get_model_memory_usage(self, model, force_calculation=False):
+        """Calculate memory usage of model parameters (CPU + GPU)"""
+        total_memory = 0
+        param_count = 0
+        gpu_memory = 0
+
+        # Count all parameters (CPU + GPU)
+        for param in model.parameters():
+            memory_bytes = param.data.element_size() * param.data.nelement()
+            total_memory += memory_bytes
+            param_count += param.data.nelement()
+
+            if param.data.is_cuda:
+                gpu_memory += memory_bytes
+
+        # Also check for quantized buffers
+        for name, buffer in model.named_buffers():
+            if 'int8_weight' in name or 'scale' in name or 'zero_point' in name:
+                memory_bytes = buffer.element_size() * buffer.nelement()
+                total_memory += memory_bytes
+
+                if buffer.is_cuda:
+                    gpu_memory += memory_bytes
+
+        # If force_calculation is True and nothing on GPU, return total memory as estimate
+        if force_calculation and gpu_memory == 0:
+            return total_memory, param_count, total_memory
+
+        return total_memory, param_count, gpu_memory
+
+    def format_memory_size(self, bytes_size):
+        """Format memory size in human readable format"""
+        for unit in ['B', 'KB', 'MB', 'GB']:
+            if bytes_size < 1024.0:
+                return f"{bytes_size:.2f} {unit}"
+            bytes_size /= 1024.0
+        return f"{bytes_size:.2f} TB"
+
+    def patch_unet(self, model, use_asymmetric, quant_dtype, use_int8_matmul,
+                   skip_input_blocks, skip_output_blocks, show_memory_usage):
+        """Apply selective quantization to UNet model with block-level control"""
+        import copy
+
+        # Measure original memory usage
+        if show_memory_usage:
+            original_memory, original_params, original_gpu = self.get_model_memory_usage(model.model, force_calculation=True)
+            print(f"📊 Original Model Memory Usage:")
+            print(f" Parameters: {original_params:,}")
+            print(f" Total Size: {self.format_memory_size(original_memory)}")
+            if original_gpu > 0:
+                print(f" GPU Memory: {self.format_memory_size(original_gpu)}")
+            else:
+                print(f" GPU Memory: Not loaded (will use ~{self.format_memory_size(original_memory)} when loaded)")
+
+        quantized_model = copy.deepcopy(model)
+
+        # Determine quantization configuration
+        quant_mode = "Asymmetric" if use_asymmetric else "Symmetric"
+        matmul_mode = "Optimized Int8" if use_int8_matmul and not use_asymmetric else "Standard"
+
+        print(f"🔧 Applying {quant_mode} 8-bit quantization to UNet")
+        print(f" MatMul: {matmul_mode}, Forward: Optimized (dtype={quant_dtype})")
+
+        if skip_input_blocks or skip_output_blocks:
+            print(f" Skipping: Input blocks={skip_input_blocks}, Output blocks={skip_output_blocks}")
+
+        # Apply quantization with selective skipping
+        self._apply_selective_quantization(
+            quantized_model.model,
+            use_asymmetric=use_asymmetric,
+            quant_dtype=quant_dtype,
+            use_int8_matmul=use_int8_matmul,
+            skip_input_blocks=skip_input_blocks,
+            skip_output_blocks=skip_output_blocks
+        )
+
+        # Measure quantized memory usage
+        if show_memory_usage:
+            quantized_memory, quantized_params, quantized_gpu = self.get_model_memory_usage(quantized_model.model, force_calculation=True)
+            memory_saved = original_memory - quantized_memory
+            memory_reduction_pct = (memory_saved / original_memory) * 100 if original_memory > 0 else 0
+
+            print(f"📊 Quantized Model Memory Usage:")
+            print(f" Parameters: {quantized_params:,}")
+            print(f" Total Size: {self.format_memory_size(quantized_memory)}")
+            if quantized_gpu > 0:
+                print(f" GPU Memory: {self.format_memory_size(quantized_gpu)}")
+            else:
+                print(f" GPU Memory: Not loaded (will use ~{self.format_memory_size(quantized_memory)} when loaded)")
+            print(f" Memory Saved: {self.format_memory_size(memory_saved)} ({memory_reduction_pct:.1f}%)")
+
+            # Show CUDA memory info if available
+            if torch.cuda.is_available():
+                allocated = torch.cuda.memory_allocated()
+                reserved = torch.cuda.memory_reserved()
+                print(f"📊 Total GPU Memory Status:")
+                print(f" Currently Allocated: {self.format_memory_size(allocated)}")
+                print(f" Reserved by PyTorch: {self.format_memory_size(reserved)}")
+
+        return (quantized_model,)
+
+    def _apply_selective_quantization(self, model, use_asymmetric=False, quant_dtype="float32",
+                                      use_int8_matmul=True, skip_input_blocks=False,
+                                      skip_output_blocks=False):
+        """Apply quantization with selective block skipping for UNet"""
+        quant_count = 0
+
+        def _quantize_module(module, prefix=""):
+            nonlocal quant_count
+            for name, child in module.named_children():
+                full_name = f"{prefix}.{name}" if prefix else name
+
+                # Skip blocks based on user preference
+                if skip_input_blocks and "input_blocks" in full_name:
+                    print(f"⏭️ Skipping {full_name} (input block)")
+                    _quantize_module(child, full_name)
+                    continue
+
+                if skip_output_blocks and "output_blocks" in full_name:
+                    print(f"⏭️ Skipping {full_name} (output block)")
+                    _quantize_module(child, full_name)
+                    continue
+
+                # Skip text encoder and CLIP-related modules
+                if any(skip_name in full_name.lower() for skip_name in
+                       ['text_encoder', 'clip', 'embedder', 'conditioner']):
+                    print(f"⏭️ Skipping {full_name} (text/conditioning module)")
+                    _quantize_module(child, full_name)
+                    continue
+
+                if isinstance(child, (nn.Linear, nn.Conv2d)):
+                    try:
+                        # Extract and quantize weights
+                        W = child.weight.data.float()
+                        qW, scale, zp = quantize_weight(W, use_asymmetric=use_asymmetric)
+
+                        # Store original device info before removing weight
+                        original_device = child.weight.device
+
+                        # Remove original weight parameter to save memory
+                        del child._parameters["weight"]
+
+                        # Register quantized parameters as buffers (non-trainable)
+                        child.register_buffer("int8_weight", qW.to(original_device))
+                        child.register_buffer("scale", scale.to(original_device))
+                        if zp is not None:
+                            child.register_buffer("zero_point", zp.to(original_device))
+                        else:
+                            child.zero_point = None
+
+                        # Configure optimization settings for this layer
+                        if isinstance(child, nn.Linear) and not use_asymmetric and use_int8_matmul:
+                            # Enable optimized matmul for symmetric quantized linear layers
+                            child._use_optimized_matmul = True
+                            # Transpose weight for optimized matmul layout
+                            child.int8_weight = child.int8_weight.transpose(0, 1).contiguous()
+                            # Adjust scale dimensions for matmul
+                            child.scale = child.scale.squeeze(-1)
+                        else:
+                            child._use_optimized_matmul = False
+
+                        # Assign optimized forward function
+                        child.forward = make_optimized_quantized_forward(
+                            quant_dtype, use_int8_matmul
+                        ).__get__(child)
+
+                        quant_count += 1
+
+                    except Exception as e:
+                        print(f"❌ Failed to quantize {full_name}: {str(e)}")
+
+                # Recursively process child modules
+                _quantize_module(child, full_name)
+
+        _quantize_module(model)
+        print(f"✅ Successfully quantized {quant_count} layers with selective patching")
+
+# ------------------------- Node Registration -------------------------------
+NODE_CLASS_MAPPINGS = {
+    "CheckpointLoaderQuantized2": CheckpointLoaderQuantized2,
+    "ModelQuantizationPatcher": ModelQuantizationPatcher,
+    "UNetQuantizationPatcher": UNetQuantizationPatcher,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CheckpointLoaderQuantized2": "CFZ Checkpoint Loader (Optimized)",
+    "ModelQuantizationPatcher": "CFZ Model Quantization Patcher",
+    "UNetQuantizationPatcher": "CFZ UNet Quantization Patcher",
+}
+
+__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
diff --git a/cfz/nodes/cfz_vae_loader.py b/cfz/nodes/cfz_vae_loader.py
new file mode 100644
index 000000000..b1023e0e0
--- /dev/null
+++ b/cfz/nodes/cfz_vae_loader.py
@@ -0,0 +1,62 @@
+import torch
+import folder_paths
+from comfy import model_management
+from nodes import VAELoader
+
+class CFZVAELoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "vae_name": (folder_paths.get_filename_list("vae"), ),
+                "precision": (["fp32", "fp16", "bf16"], {"default": "fp32"}),
+            }
+        }
+
+    RETURN_TYPES = ("VAE",)
+    FUNCTION = "load_vae"
+    CATEGORY = "loaders"
+    TITLE = "CFZ VAE Loader"
+
+    def load_vae(self, vae_name, precision):
+        # Map precision to dtype
+        dtype_map = {
+            "fp32": torch.float32,
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16
+        }
+        target_dtype = dtype_map[precision]
+
+        # Temporarily patch model_management functions
+        original_should_use_bf16 = model_management.should_use_bf16
+        original_should_use_fp16 = model_management.should_use_fp16
+
+        def custom_should_use_bf16(*args, **kwargs):
+            return precision == "bf16"
+
+        def custom_should_use_fp16(*args, **kwargs):
+            return precision == "fp16"
+
+        # Apply patches
+        model_management.should_use_bf16 = custom_should_use_bf16
+        model_management.should_use_fp16 = custom_should_use_fp16
+
+        try:
+            # Load the VAE with patched precision functions
+            vae_loader = VAELoader()
+            vae = vae_loader.load_vae(vae_name)[0]
+            print(f"CFZ VAE: Loaded with forced precision {precision}")
+            return (vae,)
+        finally:
+            # Restore original functions
+            model_management.should_use_bf16 = original_should_use_bf16
+            model_management.should_use_fp16 = original_should_use_fp16
+
+# Node mappings for ComfyUI
+NODE_CLASS_MAPPINGS = {
+    "CFZVAELoader": CFZVAELoader
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "CFZVAELoader": "CFZ VAE Loader"
+}
\ No newline at end of file
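For reference, the following standalone sketch (not part of the patch) illustrates the int8 scheme cfz_patcher.py applies: symmetric per-output-channel weight quantization as in quantize_weight, per-row activation quantization as in quantize_input_for_int8_matmul, an integer matmul in the [batch, in] x [in, out] layout, and dequantization with the combined scale. It assumes only PyTorch on CPU; the function name check_int8_roundtrip and the tolerance value are illustrative, and a plain int32 matmul stands in for torch._int_mm, which has its own device and shape requirements.

import torch
import torch.nn.functional as F

def check_int8_roundtrip(out_features=64, in_features=128, batch=4, tol=5e-2):
    """Quantize a random linear weight the way the patcher does and compare to float."""
    weight = torch.randn(out_features, in_features)
    x = torch.randn(batch, in_features)

    # Symmetric per-output-channel weight quantization (mirrors quantize_weight)
    w_scale = torch.clamp(weight.abs().amax(dim=1, keepdim=True) / 127.0, min=1e-8)
    q_weight = torch.clamp((weight / w_scale).round(), -128, 127).to(torch.int8)

    # Per-row activation quantization (mirrors quantize_input_for_int8_matmul)
    x_scale = torch.clamp(x.abs().amax(dim=-1, keepdim=True) / 127.0, min=1e-8)
    q_x = torch.clamp((x / x_scale).round(), -128, 127).to(torch.int8)

    # Integer accumulation in the [batch, in] x [in, out] layout used by the node;
    # torch._int_mm performs this step directly on int8 tensors where supported.
    acc = q_x.to(torch.int32) @ q_weight.t().to(torch.int32)

    # Dequantize with the combined scale: (input row scale) * (weight column scale)
    y_quant = acc.to(torch.float32) * (x_scale * w_scale.t())

    # Compare against the float reference
    y_ref = F.linear(x, weight)
    max_err = (y_quant - y_ref).abs().max().item()
    print(f"max abs error vs F.linear: {max_err:.4f}")
    return max_err < tol * y_ref.abs().max().item()

if __name__ == "__main__":
    print("int8 roundtrip within tolerance:", check_int8_roundtrip())

Running this prints the maximum absolute error against F.linear, which is a quick way to confirm that the transposed [in, out] weight layout and the combined input/weight scale line up the way the optimized path expects.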