diff --git a/cfz/cfz_patcher.py b/cfz/cfz_patcher.py
index 62e370640..6122d5838 100644
--- a/cfz/cfz_patcher.py
+++ b/cfz/cfz_patcher.py
@@ -350,25 +350,35 @@ class UNetQuantizationPatcher:
     CATEGORY = "Model Patching"
     OUTPUT_NODE = False
 
-    def get_model_memory_usage(self, model):
-        """Calculate actual memory usage of model parameters"""
+    def get_model_memory_usage(self, model, force_calculation=False):
+        """Calculate memory usage of model parameters (CPU + GPU)"""
         total_memory = 0
         param_count = 0
+        gpu_memory = 0
 
+        # Count all parameters (CPU + GPU)
         for param in model.parameters():
+            memory_bytes = param.data.element_size() * param.data.nelement()
+            total_memory += memory_bytes
+            param_count += param.data.nelement()
+
             if param.data.is_cuda:
-                # Get actual memory usage on GPU
-                memory_bytes = param.data.element_size() * param.data.nelement()
-                total_memory += memory_bytes
-                param_count += param.data.nelement()
+                gpu_memory += memory_bytes
 
         # Also check for quantized buffers
         for name, buffer in model.named_buffers():
-            if buffer.is_cuda and ('int8_weight' in name or 'scale' in name or 'zero_point' in name):
+            if 'int8_weight' in name or 'scale' in name or 'zero_point' in name:
                 memory_bytes = buffer.element_size() * buffer.nelement()
                 total_memory += memory_bytes
+
+                if buffer.is_cuda:
+                    gpu_memory += memory_bytes
 
-        return total_memory, param_count
+        # If force_calculation is True and nothing on GPU, return total memory as estimate
+        if force_calculation and gpu_memory == 0:
+            return total_memory, param_count, total_memory
+
+        return total_memory, param_count, gpu_memory
 
     def format_memory_size(self, bytes_size):
         """Format memory size in human readable format"""
@@ -385,10 +395,14 @@ class UNetQuantizationPatcher:
 
         # Measure original memory usage
         if show_memory_usage:
-            original_memory, original_params = self.get_model_memory_usage(model.model)
+            original_memory, original_params, original_gpu = self.get_model_memory_usage(model.model, force_calculation=True)
             print(f"📊 Original Model Memory Usage:")
             print(f"   Parameters: {original_params:,}")
-            print(f"   VRAM Usage: {self.format_memory_size(original_memory)}")
+            print(f"   Total Size: {self.format_memory_size(original_memory)}")
+            if original_gpu > 0:
+                print(f"   GPU Memory: {self.format_memory_size(original_gpu)}")
+            else:
+                print(f"   GPU Memory: Not loaded (will use ~{self.format_memory_size(original_memory)} when loaded)")
 
         quantized_model = copy.deepcopy(model)
 
@@ -414,27 +428,26 @@ class UNetQuantizationPatcher:
 
         # Measure quantized memory usage
         if show_memory_usage:
-            # Force GPU memory allocation by moving model to device if needed
-            if torch.cuda.is_available():
-                device = next(quantized_model.model.parameters()).device
-                quantized_model.model.to(device)
-
-            quantized_memory, quantized_params = self.get_model_memory_usage(quantized_model.model)
+            quantized_memory, quantized_params, quantized_gpu = self.get_model_memory_usage(quantized_model.model, force_calculation=True)
             memory_saved = original_memory - quantized_memory
             memory_reduction_pct = (memory_saved / original_memory) * 100 if original_memory > 0 else 0
 
             print(f"📊 Quantized Model Memory Usage:")
             print(f"   Parameters: {quantized_params:,}")
-            print(f"   VRAM Usage: {self.format_memory_size(quantized_memory)}")
+            print(f"   Total Size: {self.format_memory_size(quantized_memory)}")
+            if quantized_gpu > 0:
+                print(f"   GPU Memory: {self.format_memory_size(quantized_gpu)}")
+            else:
+                print(f"   GPU Memory: Not loaded (will use ~{self.format_memory_size(quantized_memory)} when loaded)")
             print(f"   Memory Saved: {self.format_memory_size(memory_saved)} ({memory_reduction_pct:.1f}%)")
 
             # Show CUDA memory info if available
             if torch.cuda.is_available():
                 allocated = torch.cuda.memory_allocated()
                 reserved = torch.cuda.memory_reserved()
-                print(f"📊 Total GPU Memory:")
-                print(f"   Allocated: {self.format_memory_size(allocated)}")
-                print(f"   Reserved: {self.format_memory_size(reserved)}")
+                print(f"📊 Total GPU Memory Status:")
+                print(f"   Currently Allocated: {self.format_memory_size(allocated)}")
+                print(f"   Reserved by PyTorch: {self.format_memory_size(reserved)}")
 
         return (quantized_model,)
 
@@ -527,4 +540,4 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "UNetQuantizationPatcher": "CFZ UNet Quantization Patcher",
 }
 
-__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
\ No newline at end of file
+__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
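
For reviewers: a minimal sketch of the new three-value return contract, not part of the patch. It assumes the node class constructs with no arguments (typical for ComfyUI node classes) and uses a hypothetical two-layer module in place of a real UNet.

# Sketch only, not part of the patch: exercises the new 3-tuple return.
# The no-arg constructor and the tiny stand-in model are assumptions.
import torch.nn as nn

from cfz.cfz_patcher import UNetQuantizationPatcher

patcher = UNetQuantizationPatcher()
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))  # hypothetical stand-in

total, count, gpu = patcher.get_model_memory_usage(model, force_calculation=True)

# On a CPU-only model no parameter is on CUDA, so gpu_memory stays 0 and
# force_calculation=True substitutes the total size as the GPU estimate.
assert gpu == total
print(f"{count:,} params, {patcher.format_memory_size(total)}")

Note that callers that previously unpacked two values must now unpack three; the diff updates both existing call sites (on model.model and quantized_model.model) accordingly.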