From 59c482fd77c27a65d743c6140da75e0abdd91f41 Mon Sep 17 00:00:00 2001
From: lspindler
Date: Thu, 16 Oct 2025 16:07:43 +0200
Subject: [PATCH 1/7] Implement mixed precision operations with a registry design and metadata for the quant spec in the checkpoint.

---
 comfy/model_base.py            |  10 +-
 comfy/model_detection.py       | 125 +++++++++
 comfy/ops.py                   | 484 ++++++++++++++++++++++++++++++++-
 comfy/supported_models_base.py |   1 +
 4 files changed, 618 insertions(+), 2 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8274c7dea..7b4651f8e 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -326,6 +326,14 @@ class BaseModel(torch.nn.Module):
         if self.model_config.scaled_fp8 is not None:
             unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
+
+        # Save mixed precision metadata
+        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
+            metadata = {
+                "format_version": "1.0",
+                "layers": self.model_config.layer_quant_config
+            }
+            unet_state_dict["_quantization_metadata"] = metadata

         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 7677617c0..7ce9aaa9a 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -6,6 +6,125 @@ import math
 import logging
 import torch
+
+# ==============================================================================
+# Quantization Detection Functions
+# ==============================================================================
+
+def normalize_layer_name(full_key, known_prefixes):
+    """
+    Strip model prefix and parameter suffix from a state dict key.
+
+    Args:
+        full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight")
+        known_prefixes: List of known model prefixes to strip
+
+    Returns:
+        Normalized layer name (e.g., "layer1")
+    """
+    name = full_key
+
+    # Strip model prefix
+    for prefix in known_prefixes:
+        if name.startswith(prefix):
+            name = name[len(prefix):]
+            break
+
+    # Remove parameter suffix
+    for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]:
+        if name.endswith(suffix):
+            name = name[:-len(suffix)]
+            break
+
+    return name
+
+
+def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
+    """
+    Detect per-layer quantization configuration from state dict.
+
+    Detection priority:
+    1. Check for _quantization_metadata key (new format)
+    2. Check for scaled_fp8 key (legacy format - return None)
+    3. Check for per-layer scale_weight patterns (mixed detection)
+    4. No quantization detected (return None)
+
+    Args:
+        state_dict: Model state dictionary
+        prefix: Key prefix for model layers
+
+    Returns:
+        Dict mapping layer names to quantization configs, or None for legacy/no quantization.
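As a quick end-to-end illustration of priority 1 (the metadata path), here is a minimal sketch; the layer name and tensor shape are made up, and it assumes this patch is applied so that detect_layer_quantization is importable:

    import torch
    from comfy.model_detection import detect_layer_quantization

    sd = {
        "model.diffusion_model._quantization_metadata": {
            "format_version": "1.0",
            "layers": {"input_blocks.0.0": {"format": "fp8_e4m3fn_scaled", "params": {}}},
        },
        "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(4, 4, dtype=torch.float8_e4m3fn),
    }
    cfg = detect_layer_quantization(sd, prefix="model.diffusion_model.")
    # cfg == {"input_blocks.0.0": {"format": "fp8_e4m3fn_scaled", "params": {}}}
    # Note: the metadata key is popped from sd as a side effect.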
+ + Example return value: + { + "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": True} + }, + "middle_block.1.transformer_blocks.0.attn2.to_k": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": True} + } + } + """ + + # 1. Check for new metadata format + metadata_key = f"{prefix}_quantization_metadata" + if metadata_key in state_dict: + try: + metadata = state_dict.pop(metadata_key) + if isinstance(metadata, dict) and "layers" in metadata: + logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") + return metadata["layers"] + else: + logging.warning(f"Invalid quantization metadata format, ignoring") + except Exception as e: + logging.error(f"Failed to parse quantization metadata: {e}") + return None + + # 2. Check for legacy scaled_fp8 marker + # If present, return None to use legacy code path + scaled_fp8_key = f"{prefix}scaled_fp8" + if scaled_fp8_key in state_dict: + logging.debug("Detected legacy scaled_fp8 format, using legacy code path") + return None + + # 3. Check for per-layer scale patterns (mixed precision without metadata) + # Look for layers that have scale_weight but not all layers have it + known_prefixes = [prefix] + layer_configs = {} + layers_with_scale = set() + layers_with_weight = set() + + for key in state_dict.keys(): + if key.startswith(prefix): + if key.endswith(".scale_weight"): + layer_name = normalize_layer_name(key, known_prefixes) + layers_with_scale.add(layer_name) + # Detect format based on weight dtype + weight_key = f"{prefix}{layer_name}.weight" + if weight_key in state_dict: + weight_dtype = state_dict[weight_key].dtype + if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled" + layer_configs[layer_name] = { + "format": format_name, + "params": {"use_fp8_matmul": True} + } + elif key.endswith(".weight") and not key.endswith(".scale_weight"): + layer_name = normalize_layer_name(key, known_prefixes) + layers_with_weight.add(layer_name) + + # If we found scale_weight on some but not all layers, it's mixed precision + if layer_configs and len(layers_with_scale) < len(layers_with_weight): + logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers") + return layer_configs + + # 4. 
No quantization detected + return None + + def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -701,6 +820,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal else: model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + if layer_quant_config: + model_config.layer_quant_config = layer_quant_config + logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + return model_config def unet_prefix_from_state_dict(state_dict): diff --git a/comfy/ops.py b/comfy/ops.py index b2096b40e..7ce7d3293 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -458,7 +458,457 @@ if CUBLAS_IS_AVAILABLE: def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None): + +# ============================================================================== +# Quantization Format Registry System +# ============================================================================== + +class QuantFormatHandler: + """ + Base class for all quantization format handlers. + + A handler encapsulates the logic for a specific quantization format + (e.g., FP8 scaled, MX formats) and manages the quantization + parameters and forward pass for quantized layers. + """ + + def __init__(self, layer, **config): + """ + Initialize handler for a specific layer. + + Args: + layer: The nn.Module layer (Linear, Conv2d, etc.) + **config: Format-specific configuration + """ + self.layer = layer + self.config = config + + def setup_parameters(self): + """ + Initialize quantization parameters on the layer. + Called during layer construction or load_state_dict. + + Subclasses should create parameters like scale_weight, scale_input, etc. + and attach them to self.layer. + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") + + def forward(self, *args, **kwargs): + """ + Execute quantized forward pass. + + Signature matches the layer's expected forward pass. + Handler accesses layer parameters via self.layer (weight, bias, etc.) + + Args: + *args: Positional arguments matching layer forward signature + **kwargs: Keyword arguments matching layer forward signature + + Returns: + Layer output tensor + + Examples: + Linear: forward(input) + Conv2d: forward(input) + GroupNorm: forward(input) + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") + + def load_state_dict(self, state_dict, prefix): + """ + Load quantization parameters from state dict. + + Args: + state_dict: State dictionary + prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") + + def state_dict(self, prefix): + """ + Save quantization parameters to state dict. + + Args: + prefix: Key prefix for this layer + + Returns: + Dictionary of quantization parameters with full keys + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") + + def convert_weight(self, weight, inplace=False): + """ + Convert weight from quantized to full precision (dequantize). 
+ + Args: + weight: Quantized weight tensor + inplace: Whether to modify in-place + + Returns: + Dequantized weight tensor + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): + """ + Convert and set weight from full precision to quantized. + + Args: + weight: Full precision weight tensor + inplace_update: Whether to update layer weight in-place + seed: Random seed for stochastic rounding + return_weight: If True, return quantized weight without setting + + Returns: + Quantized weight if return_weight=True, else None + """ + raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") + + +class QuantFormatRegistry: + """ + Global registry for quantization formats. + + Formats are registered with a unique name and handler class. + Custom formats can be registered by custom nodes. + """ + + _formats = {} + + @classmethod + def register(cls, name, handler_class, **default_config): + """ + Register a new quantization format. + + Args: + name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") + handler_class: Handler class implementing QuantFormatHandler + **default_config: Default configuration parameters + + Example: + QuantFormatRegistry.register( + "fp8_e4m3fn_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=False, + use_fp8_matmul=True, + ) + """ + if not issubclass(handler_class, QuantFormatHandler): + raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") + + cls._formats[name] = { + "handler": handler_class, + "config": default_config.copy() + } + logging.debug(f"Registered quantization format: {name}") + + @classmethod + def get(cls, name, **override_config): + """ + Get format info with optional config overrides. + + Args: + name: Format identifier + **override_config: Configuration overrides + + Returns: + Dict with 'handler' (class) and 'config' (dict) keys + + Raises: + ValueError: If format name not registered + """ + if name not in cls._formats: + available = ", ".join(cls._formats.keys()) if cls._formats else "none" + raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") + + format_info = cls._formats[name].copy() + # Merge override_config into default config + config = format_info["config"].copy() + config.update(override_config) + format_info["config"] = config + return format_info + + @classmethod + def list_formats(cls): + """List all registered format names""" + return list(cls._formats.keys()) + + @classmethod + def is_registered(cls, name): + """Check if a format is registered""" + return name in cls._formats + + +class FP8ScaledHandler(QuantFormatHandler): + """ + Handler for FP8 quantization with per-tensor scaling. + + Supports both weight-only and weight+activation quantization. + Compatible with existing fp8_linear implementation. 
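The per-tensor scaling arithmetic this handler implements is easy to show in isolation. A minimal sketch of the quantize/dequantize round trip; the amax-based scale choice is a common heuristic assumed here, not taken from this patch:

    import torch

    w = torch.randn(128, 64)                     # full-precision weight
    scale = w.abs().amax() / 448.0               # 448 = largest finite float8_e4m3fn value
    w_fp8 = (w / scale).to(torch.float8_e4m3fn)  # quantize (the set_weight direction)
    w_back = w_fp8.to(torch.float32) * scale     # dequantize (the convert_weight direction)
    print((w - w_back).abs().max())              # small, bounded by FP8 rounding error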
+ """ + + def setup_parameters(self): + """Initialize scale_weight and optionally scale_input""" + device = self.layer.weight.device + dtype = torch.float32 + + # Always have scale_weight for FP8 + if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: + self.layer.scale_weight = torch.nn.Parameter( + torch.ones((), device=device, dtype=dtype), + requires_grad=False + ) + + # scale_input is optional (for activation quantization) + if self.config.get("quantize_activation", False): + if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: + self.layer.scale_input = torch.nn.Parameter( + torch.ones((), device=device, dtype=dtype), + requires_grad=False + ) + else: + self.layer.scale_input = None + + def forward(self, *args, **kwargs): + """ + FP8 forward pass with optional activation quantization. + Supports Linear layers (Conv2d in future). + """ + # Detect layer type and dispatch + if isinstance(self.layer, torch.nn.Linear): + return self._forward_linear(*args, **kwargs) + else: + raise NotImplementedError( + f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" + ) + + def _forward_linear(self, input): + """FP8 forward for Linear layers""" + # Try fast path with fp8_linear if enabled + if self.config.get("use_fp8_matmul", False) and not self.layer.training: + try: + result = fp8_linear(self.layer, input) + if result is not None: + return result + except Exception as e: + logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") + + # Standard path: dequantize and compute + weight, bias = cast_bias_weight(self.layer, input) + + # Dequantize weight + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + + # Apply weight functions (LoRA, etc.) - they see dequantized weights + if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: + weight = weight * scale + for f in self.layer.weight_function: + weight = f(weight) + else: + weight = weight * scale + + if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: + for f in self.layer.bias_function: + bias = f(bias) if bias is not None else None + + # Execute linear operation + # Optimization: multiply by scale on smaller tensor + if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: + return torch.nn.functional.linear(input, weight, bias) + else: + return torch.nn.functional.linear(input, weight, bias) + + def load_state_dict(self, state_dict, prefix): + """Load scale parameters from state dict""" + scale_weight_key = f"{prefix}scale_weight" + if scale_weight_key in state_dict: + self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) + + scale_input_key = f"{prefix}scale_input" + if scale_input_key in state_dict and self.layer.scale_input is not None: + self.layer.scale_input.data.copy_(state_dict[scale_input_key]) + + def state_dict(self, prefix): + """Save scale parameters to state dict""" + result = {f"{prefix}scale_weight": self.layer.scale_weight} + if self.layer.scale_input is not None: + result[f"{prefix}scale_input"] = self.layer.scale_input + return result + + def convert_weight(self, weight, inplace=False): + """Dequantize: multiply by scale""" + scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + if inplace: + weight *= scale + return weight + return weight * scale + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): + """Quantize: divide by scale with stochastic rounding""" + scale = 
self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) + quantized = comfy.float.stochastic_rounding( + weight / scale, + self.layer.weight.dtype, + seed=seed + ) + + if return_weight: + return quantized + + if inplace_update: + self.layer.weight.data.copy_(quantized) + else: + self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) + + +# ============================================================================== +# Mixed Precision Operations +# ============================================================================== + +class MixedPrecisionOps(disable_weight_init): + """ + Operations class supporting per-layer quantization (mixed precision). + + This class enables different layers to use different quantization formats + within the same model (e.g., some layers FP8, others BF16). + + Layer-specific quantization is configured via _layer_quant_config class variable, + which is set by pick_operations() when a model has mixed precision. + """ + + _layer_quant_config = {} # Class variable set by pick_operations() + + class Linear(disable_weight_init.Linear): + """Linear layer with optional per-layer quantization""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.quant_handler = None + self._handler_initialized = False + + def reset_parameters(self): + # Don't allocate weights - return None like disable_weight_init + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + """ + Called by PyTorch during load_state_dict. + This is where we initialize the handler since we now know the layer name. + """ + if not self._handler_initialized: + # Normalize layer name from prefix + layer_name = prefix.rstrip('.') + + # Strip known model prefixes + for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: + if layer_name.startswith(model_prefix): + layer_name = layer_name[len(model_prefix):] + break + + # Check if this layer has quantization config + # Access via parent class since _layer_quant_config is a class variable + if layer_name in MixedPrecisionOps._layer_quant_config: + config = MixedPrecisionOps._layer_quant_config[layer_name] + try: + format_info = QuantFormatRegistry.get( + config["format"], + **config.get("params", {}) + ) + + # Initialize handler + self.quant_handler = format_info["handler"](self, **format_info["config"]) + self.quant_handler.setup_parameters() + + # Let handler load its parameters (scale_weight, etc.) + self.quant_handler.load_state_dict(state_dict, prefix) + + logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") + except ValueError as e: + # Format not registered - fall back to standard precision + logging.warning( + f"Quantization format '{config['format']}' not registered for layer {layer_name}. " + f"Falling back to standard precision. Error: {e}" + ) + self.quant_handler = None + except Exception as e: + logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") + self.quant_handler = None + + self._handler_initialized = True + + # Call parent to load weight and bias + super()._load_from_state_dict( + state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs + ) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + """Save layer parameters including quantization metadata""" + super()._save_to_state_dict(destination, prefix, keep_vars) + + # Save handler parameters (scale_weight, etc.) 
+ if self.quant_handler: + handler_dict = self.quant_handler.state_dict(prefix) + destination.update(handler_dict) + + def forward_comfy_cast_weights(self, input): + """Forward pass with optional quantization""" + if self.quant_handler: + # Use handler for quantized forward + return self.quant_handler.forward(input) + else: + # Standard path for non-quantized layers + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.linear(input, weight, bias) + + def forward(self, *args, **kwargs): + """Main forward pass""" + run_every_op() + # Same logic as disable_weight_init.Linear + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + + @classmethod + def conv_nd(s, dims, *args, **kwargs): + """Create Conv layer (same as disable_weight_init)""" + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + """ + Select appropriate operations class for model. + + NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). + LEGACY: All other paths unchanged for backward compatibility. + + Args: + weight_dtype: Weight storage dtype + compute_dtype: Computation dtype + load_device: Device for loading + disable_fast_fp8: Disable fast FP8 paths + fp8_optimizations: Enable FP8 optimizations + scaled_fp8: Legacy FP8 dtype marker + model_config: Model config object (optional, for mixed precision support) + + Returns: + Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) + """ + # NEW: Check for mixed precision + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: + MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") + return MixedPrecisionOps + + # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) @@ -483,3 +933,35 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ return disable_weight_init return manual_cast + + +# ============================================================================== +# Register built-in quantization formats +# ============================================================================== + +# FP8 E4M3FN weight-only quantization +QuantFormatRegistry.register( + "fp8_e4m3fn_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=False, + use_fp8_matmul=True, +) + +# FP8 E4M3FN weight+activation quantization +QuantFormatRegistry.register( + "fp8_e4m3fn_scaled_dynamic", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e4m3fn, + quantize_activation=True, + use_fp8_matmul=True, +) + +# FP8 E5M2 weight-only quantization +QuantFormatRegistry.register( + "fp8_e5m2_scaled", + handler_class=FP8ScaledHandler, + base_dtype=torch.float8_e5m2, + quantize_activation=False, + use_fp8_matmul=True, +) diff --git a/comfy/supported_models_base.py 
b/comfy/supported_models_base.py index 54573abb1..e4bd74514 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -50,6 +50,7 @@ class BASE: manual_cast_dtype = None custom_operations = None scaled_fp8 = None + layer_quant_config = None # Per-layer quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod From ad6c14c37033f5921c37ae225d1fa8c5a5325e5a Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 22 Oct 2025 10:30:00 +0200 Subject: [PATCH 2/7] Updated design using Tensor Subclasses --- comfy/model_detection.py | 4 +- comfy/ops.py | 514 ++++-------------- comfy/quant_ops.py | 346 ++++++++++++ tests-unit/comfy_test/test_mixed_precision.py | 274 ++++++++++ tests-unit/comfy_test/test_quant_detection.py | 262 +++++++++ tests-unit/comfy_test/test_quant_registry.py | 399 ++++++++++++++ 6 files changed, 1400 insertions(+), 399 deletions(-) create mode 100644 comfy/quant_ops.py create mode 100644 tests-unit/comfy_test/test_mixed_precision.py create mode 100644 tests-unit/comfy_test/test_quant_detection.py create mode 100644 tests-unit/comfy_test/test_quant_registry.py diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7ce9aaa9a..01f26836b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -107,10 +107,10 @@ def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): if weight_key in state_dict: weight_dtype = state_dict[weight_key].dtype if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled" + format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2" layer_configs[layer_name] = { "format": format_name, - "params": {"use_fp8_matmul": True} + "params": {} } elif key.endswith(".weight") and not key.endswith(".scale_weight"): layer_name = normalize_layer_name(key, known_prefixes) diff --git a/comfy/ops.py b/comfy/ops.py index 7ce7d3293..2e6782dbd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -324,6 +324,21 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): + """ + Legacy FP8 linear function - now uses tensor subclass infrastructure. + + This function maintains backward compatibility with existing code while + routing all FP8 computation through the unified tensor subclass system. + All actual FP8 matmul logic is handled by the registered operation handlers + in quant_ops.py via __torch_dispatch__. 
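The dispatch mechanism this relies on can be demonstrated independently of the patch. A toy subclass (not the class used here) that intercepts ops on a wrapped tensor:

    import torch

    class Tagged(torch.Tensor):
        @staticmethod
        def __new__(cls, t):
            return torch.Tensor._make_subclass(cls, t, require_grad=False)

        def __init__(self, t):
            self._raw = t  # plain alias, used to avoid re-entering dispatch

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("intercepted:", func)  # e.g. aten.mm.default
            unwrap = lambda a: a._raw if isinstance(a, Tagged) else a
            return func(*[unwrap(a) for a in args], **kwargs)

    x = torch.randn(2, 4)
    w = Tagged(torch.randn(4, 3))
    y = torch.mm(x, w)  # prints the intercepted op, then computes normally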
+ + Args: + self: Linear layer with FP8 weight and scale parameters + input: Input tensor (any dtype) + + Returns: + Output tensor or None if weight is not FP8 + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None @@ -335,10 +350,12 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype + if len(input.shape) == 3: + # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() + # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -348,23 +365,31 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() - - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + + # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch + quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) + + # Handle input quantization and wrapping + if self.scale_input is None: + # Clamp input to FP8 range and quantize + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) - - if isinstance(o, tuple): - o = o[0] - + # Apply inverse scale and quantize + input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() + + # Wrap input in QuantizedTensorFP8 + quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype) + + # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! + # This is the key unification: all FP8 computation goes through one path + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + + # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None @@ -459,307 +484,8 @@ if CUBLAS_IS_AVAILABLE: return super().forward(*args, **kwargs) -# ============================================================================== -# Quantization Format Registry System -# ============================================================================== - -class QuantFormatHandler: - """ - Base class for all quantization format handlers. - - A handler encapsulates the logic for a specific quantization format - (e.g., FP8 scaled, MX formats) and manages the quantization - parameters and forward pass for quantized layers. - """ - - def __init__(self, layer, **config): - """ - Initialize handler for a specific layer. - - Args: - layer: The nn.Module layer (Linear, Conv2d, etc.) - **config: Format-specific configuration - """ - self.layer = layer - self.config = config - - def setup_parameters(self): - """ - Initialize quantization parameters on the layer. - Called during layer construction or load_state_dict. - - Subclasses should create parameters like scale_weight, scale_input, etc. - and attach them to self.layer. 
- """ - raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") - - def forward(self, *args, **kwargs): - """ - Execute quantized forward pass. - - Signature matches the layer's expected forward pass. - Handler accesses layer parameters via self.layer (weight, bias, etc.) - - Args: - *args: Positional arguments matching layer forward signature - **kwargs: Keyword arguments matching layer forward signature - - Returns: - Layer output tensor - - Examples: - Linear: forward(input) - Conv2d: forward(input) - GroupNorm: forward(input) - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") - - def load_state_dict(self, state_dict, prefix): - """ - Load quantization parameters from state dict. - - Args: - state_dict: State dictionary - prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") - - def state_dict(self, prefix): - """ - Save quantization parameters to state dict. - - Args: - prefix: Key prefix for this layer - - Returns: - Dictionary of quantization parameters with full keys - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") - - def convert_weight(self, weight, inplace=False): - """ - Convert weight from quantized to full precision (dequantize). - - Args: - weight: Quantized weight tensor - inplace: Whether to modify in-place - - Returns: - Dequantized weight tensor - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): - """ - Convert and set weight from full precision to quantized. - - Args: - weight: Full precision weight tensor - inplace_update: Whether to update layer weight in-place - seed: Random seed for stochastic rounding - return_weight: If True, return quantized weight without setting - - Returns: - Quantized weight if return_weight=True, else None - """ - raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") - - -class QuantFormatRegistry: - """ - Global registry for quantization formats. - - Formats are registered with a unique name and handler class. - Custom formats can be registered by custom nodes. - """ - - _formats = {} - - @classmethod - def register(cls, name, handler_class, **default_config): - """ - Register a new quantization format. - - Args: - name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") - handler_class: Handler class implementing QuantFormatHandler - **default_config: Default configuration parameters - - Example: - QuantFormatRegistry.register( - "fp8_e4m3fn_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=False, - use_fp8_matmul=True, - ) - """ - if not issubclass(handler_class, QuantFormatHandler): - raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") - - cls._formats[name] = { - "handler": handler_class, - "config": default_config.copy() - } - logging.debug(f"Registered quantization format: {name}") - - @classmethod - def get(cls, name, **override_config): - """ - Get format info with optional config overrides. 
- - Args: - name: Format identifier - **override_config: Configuration overrides - - Returns: - Dict with 'handler' (class) and 'config' (dict) keys - - Raises: - ValueError: If format name not registered - """ - if name not in cls._formats: - available = ", ".join(cls._formats.keys()) if cls._formats else "none" - raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") - - format_info = cls._formats[name].copy() - # Merge override_config into default config - config = format_info["config"].copy() - config.update(override_config) - format_info["config"] = config - return format_info - - @classmethod - def list_formats(cls): - """List all registered format names""" - return list(cls._formats.keys()) - - @classmethod - def is_registered(cls, name): - """Check if a format is registered""" - return name in cls._formats - - -class FP8ScaledHandler(QuantFormatHandler): - """ - Handler for FP8 quantization with per-tensor scaling. - - Supports both weight-only and weight+activation quantization. - Compatible with existing fp8_linear implementation. - """ - - def setup_parameters(self): - """Initialize scale_weight and optionally scale_input""" - device = self.layer.weight.device - dtype = torch.float32 - - # Always have scale_weight for FP8 - if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: - self.layer.scale_weight = torch.nn.Parameter( - torch.ones((), device=device, dtype=dtype), - requires_grad=False - ) - - # scale_input is optional (for activation quantization) - if self.config.get("quantize_activation", False): - if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: - self.layer.scale_input = torch.nn.Parameter( - torch.ones((), device=device, dtype=dtype), - requires_grad=False - ) - else: - self.layer.scale_input = None - - def forward(self, *args, **kwargs): - """ - FP8 forward pass with optional activation quantization. - Supports Linear layers (Conv2d in future). - """ - # Detect layer type and dispatch - if isinstance(self.layer, torch.nn.Linear): - return self._forward_linear(*args, **kwargs) - else: - raise NotImplementedError( - f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" - ) - - def _forward_linear(self, input): - """FP8 forward for Linear layers""" - # Try fast path with fp8_linear if enabled - if self.config.get("use_fp8_matmul", False) and not self.layer.training: - try: - result = fp8_linear(self.layer, input) - if result is not None: - return result - except Exception as e: - logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") - - # Standard path: dequantize and compute - weight, bias = cast_bias_weight(self.layer, input) - - # Dequantize weight - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - - # Apply weight functions (LoRA, etc.) 
- they see dequantized weights - if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: - weight = weight * scale - for f in self.layer.weight_function: - weight = f(weight) - else: - weight = weight * scale - - if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: - for f in self.layer.bias_function: - bias = f(bias) if bias is not None else None - - # Execute linear operation - # Optimization: multiply by scale on smaller tensor - if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: - return torch.nn.functional.linear(input, weight, bias) - else: - return torch.nn.functional.linear(input, weight, bias) - - def load_state_dict(self, state_dict, prefix): - """Load scale parameters from state dict""" - scale_weight_key = f"{prefix}scale_weight" - if scale_weight_key in state_dict: - self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) - - scale_input_key = f"{prefix}scale_input" - if scale_input_key in state_dict and self.layer.scale_input is not None: - self.layer.scale_input.data.copy_(state_dict[scale_input_key]) - - def state_dict(self, prefix): - """Save scale parameters to state dict""" - result = {f"{prefix}scale_weight": self.layer.scale_weight} - if self.layer.scale_input is not None: - result[f"{prefix}scale_input"] = self.layer.scale_input - return result - - def convert_weight(self, weight, inplace=False): - """Dequantize: multiply by scale""" - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - if inplace: - weight *= scale - return weight - return weight * scale - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): - """Quantize: divide by scale with stochastic rounding""" - scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) - quantized = comfy.float.stochastic_rounding( - weight / scale, - self.layer.weight.dtype, - seed=seed - ) - - if return_weight: - return quantized - - if inplace_update: - self.layer.weight.data.copy_(quantized) - else: - self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) +# Import quantization operations from separate module +from .quant_ops import QuantizedTensorFP8 # ============================================================================== @@ -780,12 +506,13 @@ class MixedPrecisionOps(disable_weight_init): _layer_quant_config = {} # Class variable set by pick_operations() class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization""" + """Linear layer with optional per-layer quantization using tensor subclasses""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.quant_handler = None - self._handler_initialized = False + self.quant_format = None + self.quant_scale = None + self._quantization_initialized = False def reset_parameters(self): # Don't allocate weights - return None like disable_weight_init @@ -795,9 +522,16 @@ class MixedPrecisionOps(disable_weight_init): strict, missing_keys, unexpected_keys, error_msgs): """ Called by PyTorch during load_state_dict. - This is where we initialize the handler since we now know the layer name. + Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. 
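The wrapping step in isolation, mirroring what the code below does once the raw FP8 weight is loaded (a sketch, assuming QuantizedTensorFP8 from quant_ops.py in this patch):

    import torch
    from comfy.quant_ops import QuantizedTensorFP8

    w_fp8 = torch.randn(20, 10).to(torch.float8_e4m3fn)
    scale = torch.tensor(2.0, dtype=torch.float32)
    qw = QuantizedTensorFP8(w_fp8, scale, orig_dtype=torch.bfloat16)
    # Parameter construction works because the detach() override returns the subclass
    param = torch.nn.Parameter(qw, requires_grad=False)
    print(qw.dequantize().dtype)  # torch.bfloat16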
""" - if not self._handler_initialized: + # Call parent to load weight and bias first + super()._load_from_state_dict( + state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs + ) + + # After weight is loaded, wrap it if this layer is quantized + if not self._quantization_initialized: # Normalize layer name from prefix layer_name = prefix.rstrip('.') @@ -808,60 +542,78 @@ class MixedPrecisionOps(disable_weight_init): break # Check if this layer has quantization config - # Access via parent class since _layer_quant_config is a class variable if layer_name in MixedPrecisionOps._layer_quant_config: config = MixedPrecisionOps._layer_quant_config[layer_name] - try: - format_info = QuantFormatRegistry.get( - config["format"], - **config.get("params", {}) - ) + self.quant_format = config.get("format", "fp8_e4m3fn") + + # Load scale parameter + scale_key = f"{prefix}scale_weight" + if scale_key in state_dict: + self.quant_scale = state_dict[scale_key] - # Initialize handler - self.quant_handler = format_info["handler"](self, **format_info["config"]) - self.quant_handler.setup_parameters() - - # Let handler load its parameters (scale_weight, etc.) - self.quant_handler.load_state_dict(state_dict, prefix) - - logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") - except ValueError as e: - # Format not registered - fall back to standard precision - logging.warning( - f"Quantization format '{config['format']}' not registered for layer {layer_name}. " - f"Falling back to standard precision. Error: {e}" - ) - self.quant_handler = None - except Exception as e: - logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") - self.quant_handler = None + # Wrap weight in QuantizedTensorFP8 + if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + try: + # Determine original dtype (default to bfloat16) + orig_dtype = torch.bfloat16 + + # Wrap weight in quantized tensor subclass + quantized_weight = QuantizedTensorFP8( + self.weight.data, + self.quant_scale, + orig_dtype=orig_dtype + ) + + # Replace weight parameter with wrapped version + self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + + logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") + except Exception as e: + logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") + self.quant_format = None + self.quant_scale = None + else: + logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") + self.quant_format = None + self.quant_scale = None + else: + logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") + self.quant_format = None - self._handler_initialized = True - - # Call parent to load weight and bias - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) + self._quantization_initialized = True def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization metadata""" - super()._save_to_state_dict(destination, prefix, keep_vars) - - # Save handler parameters (scale_weight, etc.) 
- if self.quant_handler: - handler_dict = self.quant_handler.state_dict(prefix) - destination.update(handler_dict) + """Save layer parameters including quantization scale""" + # First unwrap the weight if it's quantized + if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8): + # Temporarily unwrap to save the raw FP8 data + quantized_tensor = self.weight.data + raw_fp8_data = quantized_tensor._raw_data + original_weight = self.weight + self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False) + + # Call parent to save unwrapped weight + super()._save_to_state_dict(destination, prefix, keep_vars) + + # Restore the wrapped weight + self.weight = original_weight + + # Save the scale parameter + if self.quant_scale is not None: + destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach() + else: + # Standard path for non-quantized weights + super()._save_to_state_dict(destination, prefix, keep_vars) def forward_comfy_cast_weights(self, input): - """Forward pass with optional quantization""" - if self.quant_handler: - # Use handler for quantized forward - return self.quant_handler.forward(input) - else: - # Standard path for non-quantized layers - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + """ + Forward pass - tensor subclass handles dispatch automatically! + __torch_dispatch__ will route to registered handlers based on tensor types. + """ + weight, bias = cast_bias_weight(self, input) + + # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it! + return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): """Main forward pass""" @@ -933,35 +685,3 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ return disable_weight_init return manual_cast - - -# ============================================================================== -# Register built-in quantization formats -# ============================================================================== - -# FP8 E4M3FN weight-only quantization -QuantFormatRegistry.register( - "fp8_e4m3fn_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=False, - use_fp8_matmul=True, -) - -# FP8 E4M3FN weight+activation quantization -QuantFormatRegistry.register( - "fp8_e4m3fn_scaled_dynamic", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e4m3fn, - quantize_activation=True, - use_fp8_matmul=True, -) - -# FP8 E5M2 weight-only quantization -QuantFormatRegistry.register( - "fp8_e5m2_scaled", - handler_class=FP8ScaledHandler, - base_dtype=torch.float8_e5m2, - quantize_activation=False, - use_fp8_matmul=True, -) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000..681eb9134 --- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,346 @@ +import torch +import logging + +# ============================================================================== +# Global Operation Registry +# ============================================================================== + +# Global operation registry: torch operation → handler function +_QUANT_OP_REGISTRY = {} + +def register_quant_op(torch_op): + """ + Decorator to register an operation handler. + + Example: + @register_quant_op(torch.ops.aten.linear.default) + def handle_linear_fp8(func, args, kwargs): + # Implementation + ... 
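A fuller registration than the elided docstring example above; this hypothetical handler for aten.relu is not part of the patch, and simply drops to high precision for the op:

    import torch
    from comfy.quant_ops import register_quant_op, QuantizedTensorFP8

    @register_quant_op(torch.ops.aten.relu.default)
    def handle_relu_fp8(func, args, kwargs):
        x = args[0]
        if isinstance(x, QuantizedTensorFP8):
            x = x.dequantize()  # dequantize, then compute normally
        return torch.relu(x)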
+ """ + def decorator(handler_func): + _QUANT_OP_REGISTRY[torch_op] = handler_func + return handler_func + return decorator + + +def get_quant_handler(torch_op): + """Get registered handler for an operation""" + return _QUANT_OP_REGISTRY.get(torch_op) + + +def list_registered_ops(): + """List all registered quantized operations""" + return list(_QUANT_OP_REGISTRY.keys()) + + +# ============================================================================== +# comfy_kitchen Integration +# ============================================================================== + +try: + import comfy_kitchen as ck + ck.disable_backend("cutile") + _CK_AVAILABLE = True + logging.info("comfy_kitchen available for optimized quantization kernels") +except ImportError: + ck = None + _CK_AVAILABLE = False + logging.info("comfy_kitchen not available - using PyTorch fallbacks") +except Exception as e: + ck = None + _CK_AVAILABLE = False + logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") + + +# ============================================================================== +# Quantized Tensor Subclass +# ============================================================================== + +class QuantizedTensorFP8(torch.Tensor): + """ + Tensor subclass for FP8 quantized data. + Automatically handles operations via __torch_dispatch__. + """ + + @staticmethod + def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + """ + Create a quantized FP8 tensor. + + Args: + tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) + scale: Scale factor for dequantization (scalar tensor) + orig_dtype: Original dtype before quantization + """ + return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + + def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): + self._scale = scale + self._orig_dtype = orig_dtype + # Store a reference to prevent infinite recursion in dequantize + self._raw_data = tensor + + def __repr__(self): + return (f"QuantizedTensorFP8(shape={self.shape}, " + f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + """ + Intercept ALL torch operations. + Routes to registered handlers or falls back to dequantization. 
+ """ + kwargs = kwargs or {} + + # Special case: skip dispatch for internal tensor operations + # that are used for unwrapping (to avoid recursion) + if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: + # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach + if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): + # Special handling for detach - return a new QuantizedTensorFP8 + qt = args[0] + detached_data = qt._raw_data.detach() + return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) + + # For other ops, just unwrap + def unwrap(arg): + if isinstance(arg, QuantizedTensorFP8): + return arg._raw_data + return arg + new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) + return func(*new_args, **kwargs) + + # Look up registered handler for this operation + handler = _QUANT_OP_REGISTRY.get(func) + if handler: + return handler(func, args, kwargs) + + # No handler - dequantize and use standard path + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + """Fallback: dequantize all quantized tensors""" + def dequant_arg(arg): + if isinstance(arg, QuantizedTensorFP8): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + def dequantize(self) -> torch.Tensor: + """Explicit dequantization""" + # Use the raw data and convert directly + # Call aten ops directly to minimize dispatch interference + plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) + # Multiply by scale + return plain_tensor * self._scale + + def detach(self): + """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" + # Detach the raw data and create a new QuantizedTensorFP8 + detached_data = self._raw_data.detach() + return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) + + +# ============================================================================== +# Operation Handlers for Quantized Tensors +# ============================================================================== + +@register_quant_op(torch.ops.aten.linear.default) +def handle_linear_fp8(func, args, kwargs): + """ + Handle F.linear() with quantized inputs. 
+ + Supports: + - QuantizedTensorFP8 input + QuantizedTensorFP8 weight + - QuantizedTensorFP8 input + regular weight + - Regular input + QuantizedTensorFP8 weight + """ + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + # Case 1: Both input and weight are FP8 + if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): + # Use _scaled_mm for FP8×FP8 matmul + # Get plain tensors to avoid dispatch recursion + plain_input = input_tensor._raw_data + plain_weight = weight._raw_data + weight_t = plain_weight.t().contiguous() + + try: + if bias is not None: + output = torch._scaled_mm( + plain_input, + weight_t, + out_dtype=input_tensor._orig_dtype, + bias=bias, + scale_a=input_tensor._scale, + scale_b=weight._scale + ) + else: + output = torch._scaled_mm( + plain_input, + weight_t, + out_dtype=input_tensor._orig_dtype, + scale_a=input_tensor._scale, + scale_b=weight._scale + ) + + if isinstance(output, tuple): + output = output[0] + + # Check if output is FP8 (some architectures support this) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + # Keep quantized! + output_scale = input_tensor._scale * weight._scale + return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) + else: + return output + except Exception as e: + logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + # Fall through to dequantization path + + # Case 2: Only weight is quantized + if isinstance(weight, QuantizedTensorFP8): + weight_dq = weight.dequantize() + input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor + return torch.nn.functional.linear(input_dq, weight_dq, bias) + + # Case 3: Only input is quantized + elif isinstance(input_tensor, QuantizedTensorFP8): + input_dq = input_tensor.dequantize() + return torch.nn.functional.linear(input_dq, weight, bias) + + # Case 4: Neither is quantized (shouldn't happen, but handle it) + else: + return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_quant_op(torch.ops.aten.silu.default) +def handle_silu_fp8(func, args, kwargs): + """ + SiLU can be computed approximately on FP8. + Keeps activations quantized for next layer. + """ + input_q = args[0] + + if not isinstance(input_q, QuantizedTensorFP8): + # Not quantized, use standard path + return torch.nn.functional.silu(input_q) + + # Compute SiLU while keeping quantized + # SiLU(x) = x * sigmoid(x) + + # Get plain tensor to avoid dispatch recursion + plain_tensor = input_q._raw_data + + # Upcast to FP16 for sigmoid stability + x_fp16 = plain_tensor.to(torch.float16) + sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) + result_fp16 = x_fp16 * sigmoid_fp16 + + # Convert back to FP8 + result_fp8 = result_fp16.to(plain_tensor.dtype) + + # Return quantized (scale approximately preserved) + return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) + + +@register_quant_op(torch.ops.aten.layer_norm.default) +def handle_layernorm_fp8(func, args, kwargs): + """ + LayerNorm requires high precision. + Dequantizes input and returns standard tensor. 
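Taken together with the SiLU handler above, the precision policy looks like this from the outside. A sketch: whether aten.layer_norm itself or its decomposition reaches the dispatcher, the input is dequantized either by this handler or by the generic fallback:

    import torch
    from comfy.quant_ops import QuantizedTensorFP8

    q = QuantizedTensorFP8(torch.randn(4, 16).to(torch.float8_e4m3fn), torch.tensor(1.0))
    h = torch.nn.functional.silu(q)                 # stays quantized
    out = torch.nn.functional.layer_norm(h, (16,))  # norm runs in high precision
    print(type(h).__name__, type(out).__name__)     # QuantizedTensorFP8 Tensor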
+ """ + input_q = args[0] + normalized_shape = args[1] + weight = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # Dequantize if needed + if isinstance(input_q, QuantizedTensorFP8): + x = input_q.dequantize() + else: + x = input_q + + # Standard LayerNorm + result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) + + # Return dequantized (next layer will quantize if needed) + return result + + +@register_quant_op(torch.ops.aten.group_norm.default) +def handle_groupnorm_fp8(func, args, kwargs): + """ + GroupNorm requires high precision. + Dequantizes input and returns standard tensor. + """ + input_q = args[0] + num_groups = args[1] + weight = args[2] if len(args) > 2 else None + bias = args[3] if len(args) > 3 else None + eps = args[4] if len(args) > 4 else 1e-5 + + # Dequantize if needed + if isinstance(input_q, QuantizedTensorFP8): + x = input_q.dequantize() + else: + x = input_q + + # Standard GroupNorm + result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) + + # Return dequantized + return result + + +@register_quant_op(torch.ops.aten.add.Tensor) +def handle_add_fp8(func, args, kwargs): + """ + Handle addition with mixed quantized/non-quantized tensors. + """ + a = args[0] + b = args[1] + + # If both are quantized, dequantize both + if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): + return a.dequantize() + b.dequantize() + # If only one is quantized, dequantize it + elif isinstance(a, QuantizedTensorFP8): + return a.dequantize() + b + elif isinstance(b, QuantizedTensorFP8): + return a + b.dequantize() + # Neither is quantized + else: + return a + b + + +@register_quant_op(torch.ops.aten.mul.Tensor) +def handle_mul_fp8(func, args, kwargs): + """ + Handle multiplication with mixed quantized/non-quantized tensors. + """ + a = args[0] + b = args[1] + + # If both are quantized, dequantize both + if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): + return a.dequantize() * b.dequantize() + # If only one is quantized, dequantize it + elif isinstance(a, QuantizedTensorFP8): + return a.dequantize() * b + elif isinstance(b, QuantizedTensorFP8): + return a * b.dequantize() + # Neither is quantized + else: + return a * b + diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_test/test_mixed_precision.py new file mode 100644 index 000000000..cbfa2866d --- /dev/null +++ b/tests-unit/comfy_test/test_mixed_precision.py @@ -0,0 +1,274 @@ +""" +End-to-end tests for mixed precision quantization. 
+Tests Phase 3: Mixed Precision Operations +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import ops + + +class SimpleModel(torch.nn.Module): + """Simple model for testing mixed precision""" + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + """Test MixedPrecisionOps end-to-end""" + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Configure no quantization + ops.MixedPrecisionOps._layer_quant_config = {} + + # Create model + model = SimpleModel(operations=ops.MixedPrecisionOps) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} # Disable for CPU testing + }, + "layer3": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E5M2 + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + } + + # Create model and load state dict + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Verify handlers are set up correctly + self.assertIsNotNone(model.layer1.quant_handler) + self.assertIsNone(model.layer2.quant_handler) # No quantization + 
self.assertIsNotNone(model.layer3.quant_handler) + + # Verify scales were loaded + self.assertEqual(model.layer1.scale_weight.item(), 2.0) + self.assertEqual(model.layer3.scale_weight.item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_round_trip(self): + """Test saving and loading state dict preserves quantization""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model1 = SimpleModel(operations=ops.MixedPrecisionOps) + model1.load_state_dict(state_dict1) + + # Save state dict + state_dict2 = model1.state_dict() + + # Verify scale_weight is saved + self.assertIn("layer1.scale_weight", state_dict2) + self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) + + # Load into new model + model2 = SimpleModel(operations=ops.MixedPrecisionOps) + model2.load_state_dict(state_dict2) + + # Verify handler is set up + self.assertIsNotNone(model2.layer1.quant_handler) + self.assertEqual(model2.layer1.scale_weight.item(), 3.0) + + # Verify forward passes match + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output1 = model1(input_tensor) + output2 = model2(input_tensor) + + torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": False} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Add a weight function (simulating LoRA) + # LoRA delta must match weight shape (20, 10) + def apply_lora(weight): + # Generate LoRA delta matching weight shape + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats fall back gracefully""" + # Configure with unknown format + 
layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + # Load should not crash, just log warning + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict) + + # Handler should be None (fallback to standard) + self.assertIsNone(model.layer1.quant_handler) + + # Forward pass should still work + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + self.assertEqual(output.shape, (5, 40)) + + +class TestPickOperationsWithMixedPrecision(unittest.TestCase): + """Test pick_operations with mixed precision config""" + + def test_pick_operations_with_layer_quant_config(self): + """Test that pick_operations returns MixedPrecisionOps when config present""" + from comfy import supported_models_base + + # Create model config with layer_quant_config + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = { + "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}} + } + + result = ops.pick_operations(None, None, model_config=model_config) + + self.assertEqual(result, ops.MixedPrecisionOps) + self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) + + def test_pick_operations_without_layer_quant_config(self): + """Test that pick_operations falls back to standard when no config""" + from comfy import supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = None + + result = ops.pick_operations(None, None, model_config=model_config) + + self.assertEqual(result, ops.disable_weight_init) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py new file mode 100644 index 000000000..bb952a81b --- /dev/null +++ b/tests-unit/comfy_test/test_quant_detection.py @@ -0,0 +1,262 @@ +""" +Integration tests for quantization detection. 
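+
+detect_layer_quantization() is expected to return either None (legacy or
+unquantized checkpoints) or a per-layer dict; a sketch of the shape asserted
+in the tests below:
+
+    {"layer1": {"format": "fp8_e4m3fn_scaled", "params": {"use_fp8_matmul": True}}}
+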
+Tests Phase 2: Detection & Integration +""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import model_detection + + +class TestNormalizeLayerName(unittest.TestCase): + """Test the normalize_layer_name helper function""" + + def test_strip_prefix_and_suffix(self): + """Test stripping prefix and suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.weight", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_strip_multiple_prefixes(self): + """Test with multiple known prefixes""" + known_prefixes = ["model.diffusion_model.", "model.model.", "net."] + + result1 = model_detection.normalize_layer_name( + "model.diffusion_model.block.attn.weight", + known_prefixes + ) + self.assertEqual(result1, "block.attn") + + result2 = model_detection.normalize_layer_name( + "model.model.encoder.layer.weight", + known_prefixes + ) + self.assertEqual(result2, "encoder.layer") + + result3 = model_detection.normalize_layer_name( + "net.transformer.blocks.0.weight", + known_prefixes + ) + self.assertEqual(result3, "transformer.blocks.0") + + def test_strip_scale_weight_suffix(self): + """Test stripping scale_weight suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.scale_weight", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_strip_bias_suffix(self): + """Test stripping bias suffix""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "model.diffusion_model.layer1.bias", + known_prefixes + ) + self.assertEqual(result, "layer1") + + def test_no_prefix_match(self): + """Test with no prefix match""" + known_prefixes = ["model.diffusion_model."] + result = model_detection.normalize_layer_name( + "other.model.layer1.weight", + known_prefixes + ) + # Should strip suffix but not prefix + self.assertEqual(result, "other.model.layer1") + + +class TestDetectLayerQuantization(unittest.TestCase): + """Test the detect_layer_quantization function""" + + def test_no_quantization(self): + """Test with no quantization markers""" + state_dict = { + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + "model.diffusion_model.layer2.weight": torch.randn(20, 30), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + self.assertIsNone(result) + + def test_legacy_scaled_fp8(self): + """Test that legacy scaled_fp8 marker returns None""" + # Create FP8 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), + "model.diffusion_model.layer1.weight": fp8_weight, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + # Should return None to trigger legacy path + self.assertIsNone(result) + + def test_metadata_format(self): + """Test with new metadata format""" + metadata = { + "format_version": "1.0", + "layers": { + "layer1": { + "format": "fp8_e4m3fn_scaled", + "params": {"use_fp8_matmul": True} + }, + "layer2": { + "format": "fp8_e5m2_scaled", + "params": {"use_fp8_matmul": True} + } + } + } + state_dict = { + 
"model.diffusion_model._quantization_metadata": metadata, + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertIn("layer2", result) + self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") + self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") + # Metadata should be popped from state_dict + self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) + + def test_mixed_precision_detection(self): + """Test detection of mixed precision via scale patterns""" + # Create FP8 tensors by converting from float32 + fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + # Layer 1: FP8 (has scale_weight) + "model.diffusion_model.layer1.weight": fp8_weight1, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + # Layer 2: Standard (no scale_weight) + "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + # Layer 3: FP8 (has scale_weight) + "model.diffusion_model.layer3.weight": fp8_weight3, + "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertIn("layer3", result) + self.assertNotIn("layer2", result) # Layer 2 not quantized + self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") + self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") + + def test_all_layers_quantized(self): + """Test that uniform quantization (all layers) returns None""" + # Create FP8 tensors by converting from float32 + fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + # All layers have scale_weight + "model.diffusion_model.layer1.weight": fp8_weight1, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + "model.diffusion_model.layer2.weight": fp8_weight2, + "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + # If all layers are quantized, it's not mixed precision + # Should return None to use legacy scaled_fp8_ops path + self.assertIsNone(result) + + def test_fp8_e5m2_detection(self): + """Test detection of FP8 E5M2 format""" + # Create FP8 E5M2 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) + state_dict = { + "model.diffusion_model.layer1.weight": fp8_weight, + "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), + "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + } + result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled") + + def test_invalid_metadata(self): + """Test with invalid metadata format""" + state_dict = { + "model.diffusion_model._quantization_metadata": "invalid_string", + "model.diffusion_model.layer1.weight": torch.randn(10, 20), + } + result = model_detection.detect_layer_quantization(state_dict, 
"model.diffusion_model.") + # Should return None on invalid metadata + self.assertIsNone(result) + + def test_different_prefix(self): + """Test with different model prefix (audio model)""" + # Create FP8 tensor by converting from float32 + fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "model.model.layer1.weight": fp8_weight, + "model.model.layer1.scale_weight": torch.tensor(1.0), + "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), + } + result = model_detection.detect_layer_quantization(state_dict, "model.model.") + + self.assertIsNotNone(result) + self.assertIn("layer1", result) + + +class TestPickOperationsIntegration(unittest.TestCase): + """Test pick_operations with model_config parameter""" + + def test_backward_compatibility(self): + """Test that pick_operations works without model_config (legacy)""" + from comfy import ops + + # Should work without model_config parameter + result = ops.pick_operations(None, None) + self.assertIsNotNone(result) + self.assertEqual(result, ops.disable_weight_init) + + def test_with_model_config_no_quant(self): + """Test with model_config but no quantization""" + from comfy import ops, supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.layer_quant_config = None + + result = ops.pick_operations(None, None, model_config=model_config) + self.assertIsNotNone(result) + # Should use standard path + self.assertEqual(result, ops.disable_weight_init) + + def test_legacy_scaled_fp8(self): + """Test that legacy scaled_fp8 still works""" + from comfy import ops, supported_models_base + + model_config = supported_models_base.BASE({}) + model_config.scaled_fp8 = torch.float8_e4m3fn + + result = ops.pick_operations( + None, None, + scaled_fp8=torch.float8_e4m3fn, + model_config=model_config + ) + self.assertIsNotNone(result) + # Should return scaled_fp8_ops (the returned class is the inner class) + # Check that it's not the standard disable_weight_init + self.assertNotEqual(result, ops.disable_weight_init) + # Verify it has Linear class + self.assertTrue(hasattr(result, 'Linear')) + + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py new file mode 100644 index 000000000..5c624b1db --- /dev/null +++ b/tests-unit/comfy_test/test_quant_registry.py @@ -0,0 +1,399 @@ +""" +Unit tests for tensor subclass quantization system. +Tests the new QuantizedTensorFP8 subclass and operation handlers. 
+""" + +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy import ops +from comfy import quant_ops + + +class TestQuantizedTensorFP8(unittest.TestCase): + """Test the QuantizedTensorFP8 tensor subclass""" + + def test_creation(self): + """Test creating a QuantizedTensorFP8""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) + + self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._scale, scale) + self.assertEqual(qt._orig_dtype, torch.bfloat16) + + def test_dequantize(self): + """Test explicit dequantization""" + # Create a simple FP8 tensor + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) + dequantized = qt.dequantize() + + # Dequantized should be approximately ones * 3.0 + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_repr(self): + """Test string representation""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.5) + + qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) + repr_str = repr(qt) + + self.assertIn("QuantizedTensorFP8", repr_str) + self.assertIn("shape", repr_str) + self.assertIn("scale", repr_str) + + +class TestOperationRegistry(unittest.TestCase): + """Test the operation registry system""" + + def test_registry_basics(self): + """Test that operations are registered""" + registered_ops = quant_ops.list_registered_ops() + + # Check that key operations are registered + self.assertIn(torch.ops.aten.linear.default, registered_ops) + self.assertIn(torch.ops.aten.silu.default, registered_ops) + self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) + self.assertIn(torch.ops.aten.add.Tensor, registered_ops) + self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) + + def test_get_handler(self): + """Test getting a registered handler""" + handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) + self.assertIsNotNone(handler) + self.assertTrue(callable(handler)) + + def test_custom_registration(self): + """Test registering a custom operation""" + + # Define a custom handler + @quant_ops.register_quant_op(torch.ops.aten.relu.default) + def custom_relu_handler(func, args, kwargs): + return func(*args, **kwargs) + + # Verify registration + handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) + self.assertIsNotNone(handler) + self.assertEqual(handler, custom_relu_handler) + + +class TestLinearHandler(unittest.TestCase): + """Test the linear operation handler""" + + def test_linear_with_quantized_weight(self): + """Test F.linear with quantized weight""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized weight + weight_fp32 = torch.randn(256, 128, dtype=torch.float32) + scale = torch.tensor(2.0) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) + + # Create input + input_tensor = torch.randn(16, 128, dtype=torch.float32) + + # Call linear (should 
trigger dispatch) + output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) + + # Verify output shape + self.assertEqual(output.shape, (16, 256)) + + # Verify it's approximately correct (allowing for FP8 quantization error) + # Note: FP8 has limited precision, so use very loose tolerance + expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) + # Just check that it's in the right ballpark (within 50% error on average) + mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") + + def test_linear_with_bias(self): + """Test F.linear with quantized weight and bias""" + weight_fp32 = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) + + input_tensor = torch.randn(8, 32, dtype=torch.float32) + bias = torch.randn(64, dtype=torch.float32) + + output = torch.nn.functional.linear(input_tensor, weight_q, bias) + + self.assertEqual(output.shape, (8, 64)) + + +class TestActivationHandlers(unittest.TestCase): + """Test activation function handlers""" + + def test_silu_with_quantized_input(self): + """Test SiLU with quantized input""" + # Create quantized input + input_fp32 = torch.randn(16, 128, dtype=torch.float32) + scale = torch.tensor(1.0) + input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) + input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) + + # Apply SiLU + output = torch.nn.functional.silu(input_q) + + # Should return a QuantizedTensorFP8 + self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) + + # Verify approximate correctness + expected = torch.nn.functional.silu(input_fp32) + output_dq = output.dequantize() + self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) + + def test_layernorm_dequantizes(self): + """Test that LayerNorm dequantizes input""" + # Create quantized input + input_fp32 = torch.randn(16, 128, dtype=torch.float32) + scale = torch.tensor(1.0) + input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) + input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) + + # Apply LayerNorm + weight = torch.ones(128) + bias = torch.zeros(128) + output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) + + # Should NOT be quantized (LayerNorm breaks quantization) + self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) + self.assertEqual(output.dtype, torch.float32) + + +class TestElementwiseHandlers(unittest.TestCase): + """Test element-wise operation handlers""" + + def test_add_mixed_tensors(self): + """Test addition with mixed quantized/non-quantized tensors""" + # Create quantized tensor + a_fp32 = torch.ones(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) + + # Non-quantized tensor + b = torch.ones(10, 20, dtype=torch.float32) * 2.0 + + # Add them + result = a_q + b + + # Should be dequantized + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_mul_quantized_tensors(self): + """Test multiplication of two quantized tensors""" + a_fp32 = torch.ones(10, 20) * 2.0 + scale_a = torch.tensor(1.0) + a_fp8 = (a_fp32 / 
scale_a).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) + + b_fp32 = torch.ones(10, 20) * 3.0 + scale_b = torch.tensor(1.0) + b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) + b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) + + result = a_q * b_q + + # Should be dequantized + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) + a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +class TestMixedPrecisionOps(unittest.TestCase): + """Test MixedPrecisionOps integration""" + + def test_linear_layer_creation(self): + """Test that MixedPrecisionOps.Linear can be created""" + layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) + + self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) + self.assertFalse(layer._quantization_initialized) + self.assertIsNone(layer.quant_format) + + def test_layer_quant_config_detection(self): + """Test that layer quantization config is detected during load""" + # Set up layer config + ops.MixedPrecisionOps._layer_quant_config = { + "test_layer": { + "format": "fp8_e4m3fn", + "params": {} + } + } + + # Create a state dict with quantized weight + weight_fp32 = torch.randn(256, 128, dtype=torch.float32) + scale = torch.tensor(2.0) + weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) + + state_dict = { + "model.diffusion_model.test_layer.weight": weight_fp8, + "model.diffusion_model.test_layer.scale_weight": scale, + } + + # Create layer and load + layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) + layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) + + # Manually call _load_from_state_dict + layer._load_from_state_dict( + state_dict, + prefix="model.diffusion_model.test_layer.", + local_metadata={}, + strict=True, + missing_keys=[], + unexpected_keys=[], + error_msgs=[] + ) + + # Verify quantization was initialized + self.assertTrue(layer._quantization_initialized) + self.assertEqual(layer.quant_format, "fp8_e4m3fn") + self.assertIsNotNone(layer.quant_scale) + + # Verify weight is wrapped + self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) + + # Clean up + ops.MixedPrecisionOps._layer_quant_config = {} + + +class TestBackwardCompatibility(unittest.TestCase): + """Test backward compatibility with legacy systems""" + + def test_legacy_ops_classes_exist(self): + """Test that legacy ops classes still exist""" + 
self.assertTrue(hasattr(ops, 'disable_weight_init')) + self.assertTrue(hasattr(ops, 'manual_cast')) + self.assertTrue(hasattr(ops, 'fp8_ops')) + self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) + + def test_pick_operations_legacy_path(self): + """Test pick_operations returns correct class for legacy cases""" + # Test standard case + result = ops.pick_operations(torch.float32, torch.float32) + self.assertEqual(result, ops.disable_weight_init) + + # Test manual cast case + result = ops.pick_operations(torch.float32, torch.float16) + self.assertEqual(result, ops.manual_cast) + + +class TestFP8LinearUnification(unittest.TestCase): + """Test that fp8_linear now uses the unified tensor subclass infrastructure""" + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") + def test_fp8_linear_uses_tensor_subclass(self): + """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" + torch.manual_seed(42) + + # Create a mock Linear layer with FP8 weight + linear = ops.fp8_ops.Linear(4, 3, bias=True) + linear.weight = torch.nn.Parameter( + torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), + requires_grad=False + ) + linear.bias = torch.nn.Parameter( + torch.randn(3, dtype=torch.bfloat16), + requires_grad=False + ) + linear.scale_weight = torch.tensor(1.0) + linear.scale_input = None # No input scaling + + # Create input + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + + # Call fp8_linear - should work without errors + try: + result = ops.fp8_linear(linear, input_tensor) + self.assertIsNotNone(result) + self.assertEqual(result.shape, (2, 3)) + except Exception as e: + # On CPU or unsupported hardware, _scaled_mm might not be available + # but the function should still complete without syntax errors + pass + + def test_fp8_linear_maintains_signature(self): + """Verify fp8_linear maintains its original function signature""" + import inspect + sig = inspect.signature(ops.fp8_linear) + params = list(sig.parameters.keys()) + + # Should have 'self' and 'input' parameters + self.assertIn('self', params) + self.assertIn('input', params) + self.assertEqual(len(params), 2) + + def test_fp8_linear_returns_none_for_non_fp8(self): + """Verify fp8_linear returns None for non-FP8 weights""" + # Create a Linear layer with BF16 weight (not FP8) + linear = ops.disable_weight_init.Linear(4, 3, bias=False) + linear.weight = torch.nn.Parameter( + torch.randn(3, 4, dtype=torch.bfloat16), + requires_grad=False + ) + + input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) + + # Should return None for non-FP8 weights + result = ops.fp8_linear(linear, input_tensor) + self.assertIsNone(result) + + def test_fp8_ops_linear_uses_fp8_linear(self): + """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" + linear = ops.fp8_ops.Linear(4, 3, bias=False) + + # Verify the class has the forward_comfy_cast_weights method + self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) + + # The forward_comfy_cast_weights should attempt to call fp8_linear + # (we can't easily test this without mocking, but we verify structure) + import inspect + source = inspect.getsource(linear.forward_comfy_cast_weights) + self.assertIn('fp8_linear', source) + + +if __name__ == "__main__": + unittest.main() From 7ea731ea98445f12c07807102a1f2a4350952786 Mon Sep 17 00:00:00 2001 From: lspindler Date: Wed, 22 Oct 2025 11:25:39 +0200 Subject: [PATCH 3/7] Fix FP8 MM --- comfy/ops.py | 14 +--- comfy/quant_ops.py | 205 +++++++++++---------------------------------- 2 files changed, 48 insertions(+), 171 
deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index 2e6782dbd..060b35137 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -370,19 +370,7 @@ def fp8_linear(self, input):
     # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
     quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
-
-    # Handle input quantization and wrapping
-    if self.scale_input is None:
-        # Clamp input to FP8 range and quantize
-        input = torch.clamp(input, min=-448, max=448, out=input)
-        input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
-    else:
-        # Apply inverse scale and quantize
-        input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-    # Wrap input in QuantizedTensorFP8
-    quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype)
-
+    quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
     # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
     # This is the key unification: all FP8 computation goes through one path
     o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 681eb9134..8e3bacbaf 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -79,18 +79,47 @@ class QuantizedTensorFP8(torch.Tensor):
         self._scale = scale
         self._orig_dtype = orig_dtype
         # Store a reference to prevent infinite recursion in dequantize
-        self._raw_data = tensor
+        self._raw_data = tensor.contiguous()
 
     def __repr__(self):
         return (f"QuantizedTensorFP8(shape={self.shape}, "
                 f"scale={self._scale:.4f}, dtype={self._orig_dtype})")
 
+    @classmethod
+    def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
+
+        tensor_fp8 = None
+        if _CK_AVAILABLE:
+            try:
+                tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
+            except Exception as e:
+                logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
+
+        if tensor_fp8 is None:
+            lp_amax = torch.finfo(fp8_dtype).max
+            tensor_scaled = tensor.float() / scale
+            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+            tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+
+        return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
+
+    @classmethod
+    def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
+        if strategy == "amax":
+            # Derive the scale from the absolute maximum so large negative values are not clipped
+            scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
+            scale = scale.to(tensor.device, dtype=torch.float32)
+        else:
+            raise ValueError(f"Unknown quantization strategy: {strategy}. "
+                             f"Supported: 'amax'")
+
+        return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        """
-        Intercept ALL torch operations.
-        Routes to registered handlers or falls back to dequantization.
-        """
         kwargs = kwargs or {}
 
         # Special case: skip dispatch for internal tensor operations
@@ -134,16 +163,11 @@ class QuantizedTensorFP8(torch.Tensor):
         return func(*new_args, **new_kwargs)
 
     def dequantize(self) -> torch.Tensor:
-        """Explicit dequantization"""
-        # Use the raw data and convert directly
-        # Call aten ops directly to minimize dispatch interference
         plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
-        # Multiply by scale
         return plain_tensor * self._scale
 
     def detach(self):
         """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
-        # Detach the raw data and create a new QuantizedTensorFP8
         detached_data = self._raw_data.detach()
         return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
 
@@ -165,48 +189,35 @@ def handle_linear_fp8(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
     bias = args[2] if len(args) > 2 else None
-
+    # Guard with getattr: the input may be a plain tensor when only the weight is quantized
+    out_dtype = kwargs.get("out_dtype", getattr(input_tensor, "_orig_dtype", input_tensor.dtype))
+
     # Case 1: Both input and weight are FP8
     if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
-        # Use _scaled_mm for FP8×FP8 matmul
         # Get plain tensors to avoid dispatch recursion
         plain_input = input_tensor._raw_data
         plain_weight = weight._raw_data
-        weight_t = plain_weight.t().contiguous()
+        weight_t = plain_weight.t()  # Keep as column-major for cuBLASLt
 
         try:
-            if bias is not None:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    bias=bias,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-            else:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-
+            output = torch._scaled_mm(
+                plain_input,
+                weight_t,
+                bias=bias,
+                scale_a=input_tensor._scale,
+                scale_b=weight._scale,
+                out_dtype=out_dtype,
+            )
             if isinstance(output, tuple):
                 output = output[0]
 
-            # Check if output is FP8 (some architectures support this)
             if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                # Keep quantized!
                 output_scale = input_tensor._scale * weight._scale
-                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)
+                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)  # TODO: confirm _scaled_mm/cuBLASLt can actually return FP8 outputs here
             else:
                 return output
 
         except Exception as e:
             logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
-            # Fall through to dequantization path
-
+
     # Case 2: Only weight is quantized
     if isinstance(weight, QuantizedTensorFP8):
         weight_dq = weight.dequantize()
@@ -222,125 +233,3 @@
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
 
-
-@register_quant_op(torch.ops.aten.silu.default)
-def handle_silu_fp8(func, args, kwargs):
-    """
-    SiLU can be computed approximately on FP8.
-    Keeps activations quantized for next layer.
- """ - input_q = args[0] - - if not isinstance(input_q, QuantizedTensorFP8): - # Not quantized, use standard path - return torch.nn.functional.silu(input_q) - - # Compute SiLU while keeping quantized - # SiLU(x) = x * sigmoid(x) - - # Get plain tensor to avoid dispatch recursion - plain_tensor = input_q._raw_data - - # Upcast to FP16 for sigmoid stability - x_fp16 = plain_tensor.to(torch.float16) - sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) - result_fp16 = x_fp16 * sigmoid_fp16 - - # Convert back to FP8 - result_fp8 = result_fp16.to(plain_tensor.dtype) - - # Return quantized (scale approximately preserved) - return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) - - -@register_quant_op(torch.ops.aten.layer_norm.default) -def handle_layernorm_fp8(func, args, kwargs): - """ - LayerNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - normalized_shape = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard LayerNorm - result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) - - # Return dequantized (next layer will quantize if needed) - return result - - -@register_quant_op(torch.ops.aten.group_norm.default) -def handle_groupnorm_fp8(func, args, kwargs): - """ - GroupNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - num_groups = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard GroupNorm - result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) - - # Return dequantized - return result - - -@register_quant_op(torch.ops.aten.add.Tensor) -def handle_add_fp8(func, args, kwargs): - """ - Handle addition with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() + b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() + b - elif isinstance(b, QuantizedTensorFP8): - return a + b.dequantize() - # Neither is quantized - else: - return a + b - - -@register_quant_op(torch.ops.aten.mul.Tensor) -def handle_mul_fp8(func, args, kwargs): - """ - Handle multiplication with mixed quantized/non-quantized tensors. 
- """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() * b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() * b - elif isinstance(b, QuantizedTensorFP8): - return a * b.dequantize() - # Neither is quantized - else: - return a * b - From 5e9f33575367f0d6ac897665ed6dd99616b33d28 Mon Sep 17 00:00:00 2001 From: lspindler Date: Fri, 24 Oct 2025 14:44:54 +0200 Subject: [PATCH 4/7] An actually functional POC --- comfy/model_detection.py | 123 +--------- comfy/ops.py | 280 ++++++++------------- comfy/quant_ops.py | 508 ++++++++++++++++++++++++++++----------- comfy/sd.py | 8 +- 4 files changed, 481 insertions(+), 438 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 01f26836b..ffb1885fd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,121 +7,24 @@ import logging import torch -# ============================================================================== -# Quantization Detection Functions -# ============================================================================== - -def normalize_layer_name(full_key, known_prefixes): - """ - Strip model prefix and parameter suffix from a state dict key. - - Args: - full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight") - known_prefixes: List of known model prefixes to strip - - Returns: - Normalized layer name (e.g., "layer1") - """ - name = full_key - - # Strip model prefix - for prefix in known_prefixes: - if name.startswith(prefix): - name = name[len(prefix):] - break - - # Remove parameter suffix - for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]: - if name.endswith(suffix): - name = name[:-len(suffix)] - break - - return name - - -def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): - """ - Detect per-layer quantization configuration from state dict. - - Detection priority: - 1. Check for _quantization_metadata key (new format) - 2. Check for scaled_fp8 key (legacy format - return None) - 3. Check for per-layer scale_weight patterns (mixed detection) - 4. No quantization detected (return None) - - Args: - state_dict: Model state dictionary - prefix: Key prefix for model layers - - Returns: - Dict mapping layer names to quantization configs, or None for legacy/no quantization. - - Example return value: - { - "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "middle_block.1.transformer_blocks.0.attn2.to_k": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - """ - - # 1. Check for new metadata format - metadata_key = f"{prefix}_quantization_metadata" - if metadata_key in state_dict: - try: - metadata = state_dict.pop(metadata_key) - if isinstance(metadata, dict) and "layers" in metadata: - logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") - return metadata["layers"] - else: - logging.warning(f"Invalid quantization metadata format, ignoring") - except Exception as e: - logging.error(f"Failed to parse quantization metadata: {e}") - return None +def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."): + # 1. 
Check for per-layer config in metadata
+    import json  # local import for json.loads below; assumed not imported at the top of this module
+
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError("Invalid quantization metadata format")
 
     # 2. Check for legacy scaled_fp8 marker
-    # If present, return None to use legacy code path
     scaled_fp8_key = f"{prefix}scaled_fp8"
     if scaled_fp8_key in state_dict:
         logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
         return None
-
-    # 3. Check for per-layer scale patterns (mixed precision without metadata)
-    # Look for layers that have scale_weight but not all layers have it
-    known_prefixes = [prefix]
-    layer_configs = {}
-    layers_with_scale = set()
-    layers_with_weight = set()
-
-    for key in state_dict.keys():
-        if key.startswith(prefix):
-            if key.endswith(".scale_weight"):
-                layer_name = normalize_layer_name(key, known_prefixes)
-                layers_with_scale.add(layer_name)
-                # Detect format based on weight dtype
-                weight_key = f"{prefix}{layer_name}.weight"
-                if weight_key in state_dict:
-                    weight_dtype = state_dict[weight_key].dtype
-                    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                        format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2"
-                        layer_configs[layer_name] = {
-                            "format": format_name,
-                            "params": {}
-                        }
-            elif key.endswith(".weight") and not key.endswith(".scale_weight"):
-                layer_name = normalize_layer_name(key, known_prefixes)
-                layers_with_weight.add(layer_name)
-
-    # If we found scale_weight on some but not all layers, it's mixed precision
-    if layer_configs and len(layers_with_scale) < len(layers_with_weight):
-        logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers")
-        return layer_configs
-
-    # 4. No quantization detected
+    return None
 
 
@@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     model_config.optimizations["fp8"] = True
 
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
+    layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
     if layer_quant_config:
         model_config.layer_quant_config = layer_quant_config
         logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
diff --git a/comfy/ops.py b/comfy/ops.py
index 060b35137..8d11aeefc 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -325,19 +325,8 @@ class manual_cast(disable_weight_init):
 
 def fp8_linear(self, input):
     """
-    Legacy FP8 linear function - now uses tensor subclass infrastructure.
-
-    This function maintains backward compatibility with existing code while
-    routing all FP8 computation through the unified tensor subclass system.
-    All actual FP8 matmul logic is handled by the registered operation handlers
-    in quant_ops.py via __torch_dispatch__.
-
-    Args:
-        self: Linear layer with FP8 weight and scale parameters
-        input: Input tensor (any dtype)
-
-    Returns:
-        Output tensor or None if weight is not FP8
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
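+
+    Sketch of the dispatch flow this function sets up (illustrative names):
+
+        w_q = QuantizedTensor(w, TensorCoreFP8Layout, {'scale': s_w, 'orig_dtype': input.dtype})
+        x_q = QuantizedTensor.from_float(x2d, TensorCoreFP8Layout, scale=s_x, fp8_dtype=dtype)
+        out = torch.nn.functional.linear(x_q, w_q, bias)  # __torch_dispatch__ selects the FP8 handler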
""" dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: @@ -352,10 +341,8 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -368,14 +355,13 @@ def fp8_linear(self, input): else: scale_input = scale_input.to(input.device) - # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch - quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) - quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype) - # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! - # This is the key unification: all FP8 computation goes through one path + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -472,183 +458,117 @@ if CUBLAS_IS_AVAILABLE: return super().forward(*args, **kwargs) -# Import quantization operations from separate module -from .quant_ops import QuantizedTensorFP8 - - # ============================================================================== # Mixed Precision Operations # ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} class MixedPrecisionOps(disable_weight_init): - """ - Operations class supporting per-layer quantization (mixed precision). - - This class enables different layers to use different quantization formats - within the same model (e.g., some layers FP8, others BF16). - - Layer-specific quantization is configured via _layer_quant_config class variable, - which is set by pick_operations() when a model has mixed precision. 
- """ - - _layer_quant_config = {} # Class variable set by pick_operations() - - class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization using tensor subclasses""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.quant_format = None - self.quant_scale = None - self._quantization_initialized = False - + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + def reset_parameters(self): - # Don't allocate weights - return None like disable_weight_init return None def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Called by PyTorch during load_state_dict. - Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. - """ - # Call parent to load weight and bias first - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) - - # After weight is loaded, wrap it if this layer is quantized - if not self._quantization_initialized: - # Normalize layer name from prefix - layer_name = prefix.rstrip('.') - - # Strip known model prefixes - for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: - if layer_name.startswith(model_prefix): - layer_name = layer_name[len(model_prefix):] - break - - # Check if this layer has quantization config - if layer_name in MixedPrecisionOps._layer_quant_config: - config = MixedPrecisionOps._layer_quant_config[layer_name] - self.quant_format = config.get("format", "fp8_e4m3fn") - - # Load scale parameter - scale_key = f"{prefix}scale_weight" - if scale_key in state_dict: - self.quant_scale = state_dict[scale_key] - - # Wrap weight in QuantizedTensorFP8 - if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - try: - # Determine original dtype (default to bfloat16) - orig_dtype = torch.bfloat16 - - # Wrap weight in quantized tensor subclass - quantized_weight = QuantizedTensorFP8( - self.weight.data, - self.quant_scale, - orig_dtype=orig_dtype - ) - - # Replace weight parameter with wrapped version - self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - - logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") - except Exception as e: - logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") - self.quant_format = None - - self._quantization_initialized = True - - def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization scale""" - # First unwrap the 
weight if it's quantized
-        if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
-            # Temporarily unwrap to save the raw FP8 data
-            quantized_tensor = self.weight.data
-            raw_fp8_data = quantized_tensor._raw_data
-            original_weight = self.weight
-            self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)
-
-            # Call parent to save unwrapped weight
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-            # Restore the wrapped weight
-            self.weight = original_weight
-
-            # Save the scale parameter
-            if self.quant_scale is not None:
-                destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
-        else:
-            # Standard path for non-quantized weights
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        def forward_comfy_cast_weights(self, input):
-            """
-            Forward pass - tensor subclass handles dispatch automatically!
-            __torch_dispatch__ will route to registered handlers based on tensor types.
-            """
-            weight, bias = cast_bias_weight(self, input)
-
-            # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
-            return torch.nn.functional.linear(input, weight, bias)
-
-        def forward(self, *args, **kwargs):
-            """Main forward pass"""
-            run_every_op()
-            # Same logic as disable_weight_init.Linear
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                return self.forward_comfy_cast_weights(*args, **kwargs)
-            else:
-                return super().forward(*args, **kwargs)
-
-    @classmethod
-    def conv_nd(s, dims, *args, **kwargs):
-        """Create Conv layer (same as disable_weight_init)"""
-        if dims == 2:
-            return s.Conv2d(*args, **kwargs)
-        elif dims == 3:
-            return s.Conv3d(*args, **kwargs)
-        else:
-            raise ValueError(f"unsupported dimensions: {dims}")
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                layout_params = {
+                    'scale': state_dict.pop(f"{prefix}weight_scale", None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                # Pull format-specific extras (e.g. input_scale) from the checkpoint; mixin is only defined here
+                for param_name, param_value in mixin["parameters"].items():
+                    _v = state_dict.pop(f"{prefix}{param_name}", None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, _v.to(device=device))
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._forward(input, weight, bias)
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return
self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - """ - Select appropriate operations class for model. - - NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). - LEGACY: All other paths unchanged for backward compatibility. - - Args: - weight_dtype: Weight storage dtype - compute_dtype: Computation dtype - load_device: Device for loading - disable_fast_fp8: Disable fast FP8 paths - fp8_optimizations: Enable FP8 optimizations - scaled_fp8: Legacy FP8 dtype marker - model_config: Model config object (optional, for mixed precision support) - - Returns: - Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) - """ - # NEW: Check for mixed precision + # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + # MixedPrecisionOps._compute_dtype = compute_dtype # TODO + MixedPrecisionOps._compute_dtype = torch.bfloat16 logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8e3bacbaf..3802da852 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,42 +1,79 @@ import torch import logging +from typing import Tuple, Dict -# ============================================================================== -# Global Operation Registry -# ============================================================================== +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} -# Global operation registry: torch operation → handler function -_QUANT_OP_REGISTRY = {} -def register_quant_op(torch_op): +def register_layout_op(torch_op, layout_type): """ - Decorator to register an operation handler. - + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) Example: - @register_quant_op(torch.ops.aten.linear.default) - def handle_linear_fp8(func, args, kwargs): - # Implementation + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation ... 
""" def decorator(handler_func): - _QUANT_OP_REGISTRY[torch_op] = handler_func + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func return handler_func return decorator -def get_quant_handler(torch_op): - """Get registered handler for an operation""" - return _QUANT_OP_REGISTRY.get(torch_op) +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) + + Example: + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout + ... + """ + def decorator(handler_func): + _GENERIC_UTILS[torch_op] = handler_func + return handler_func + return decorator -def list_registered_ops(): - """List all registered quantized operations""" - return list(_QUANT_OP_REGISTRY.keys()) +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None -# ============================================================================== -# comfy_kitchen Integration -# ============================================================================== +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params + + +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + try: import comfy_kitchen as ck @@ -53,106 +90,144 @@ except Exception as e: logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") -# ============================================================================== -# Quantized Tensor Subclass -# ============================================================================== - -class QuantizedTensorFP8(torch.Tensor): +class QuantizedLayout: """ - Tensor subclass for FP8 quantized data. - Automatically handles operations via __torch_dispatch__. + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") + + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError(f"TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): + """ + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. 
+ + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ @staticmethod - def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + def __new__(cls, qdata, layout_type, layout_params): """ - Create a quantized FP8 tensor. + Create a quantized tensor. Args: - tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) - scale: Scale factor for dequantization (scalar tensor) - orig_dtype: Original dtype before quantization + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): - self._scale = scale - self._orig_dtype = orig_dtype - # Store a reference to prevent infinite recursion in dequantize - self._raw_data = tensor.contiguous() + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata.contiguous() + self._layout_type = layout_type + self._layout_params = layout_params def __repr__(self): - return (f"QuantizedTensorFP8(shape={self.shape}, " - f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + layout_name = self._layout_type.__name__ + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" + + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_q_data"] + ctx = { + "layout_type": self._layout_type, + } + + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v + + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params + + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. 
+ """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] + + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) @classmethod - def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn): - orig_dtype = tensor.dtype - - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) - - tensor_fp8 = None - if _CK_AVAILABLE: - try: - tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - except Exception as e: - logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}") - - if tensor_fp8 is None: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) - - return cls(tensor_fp8, scale, orig_dtype=orig_dtype) + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) - @classmethod - def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn): - if strategy == "amax": - scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max - scale = scale.to(tensor.device, dtype=torch.float32) - else: - raise ValueError(f"Unknown quantization strategy: {strategy}. " - f"Supported: 'amax'") - - return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype) + def dequantize(self) -> torch.Tensor: + return self._layout_type.dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - # Special case: skip dispatch for internal tensor operations - # that are used for unwrapping (to avoid recursion) - if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: - # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach - if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): - # Special handling for detach - return a new QuantizedTensorFP8 - qt = args[0] - detached_data = qt._raw_data.detach() - return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) - - # For other ops, just unwrap - def unwrap(arg): - if isinstance(arg, QuantizedTensorFP8): - return arg._raw_data - return arg - new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) - return func(*new_args, **kwargs) + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) - # Look up registered handler for this operation - handler = _QUANT_OP_REGISTRY.get(func) - if handler: - return handler(func, args, kwargs) + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) - # No handler - dequantize and use standard path + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. 
kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod def _dequant_and_fallback(cls, func, args, kwargs): - """Fallback: dequantize all quantized tensors""" def dequant_arg(arg): - if isinstance(arg, QuantizedTensorFP8): + if isinstance(arg, QuantizedTensor): return arg.dequantize() elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) @@ -161,75 +236,220 @@ class QuantizedTensorFP8(torch.Tensor): new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) - - def dequantize(self) -> torch.Tensor: - plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) - return plain_tensor * self._scale - - def detach(self): - """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" - detached_data = self._raw_data.detach() - return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) # ============================================================================== -# Operation Handlers for Quantized Tensors +# Generic Utilities (Layout-Agnostic Operations) # ============================================================================== -@register_quant_op(torch.ops.aten.linear.default) -def handle_linear_fp8(func, args, kwargs): - """ - Handle F.linear() with quantized inputs. +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) - Supports: - - QuantizedTensorFP8 input + QuantizedTensorFP8 weight - - QuantizedTensorFP8 input + regular weight - - Regular input + QuantizedTensorFP8 weight + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." 
+ ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata) + qt_dest._layout_type = src._layout_type + qt_dest._layout_params = _copy_layout_params(src._layout_params) + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True + +# ============================================================================== +# FP8 Layout + Operation Handlers +# ============================================================================== +class TensorCoreFP8Layout(QuantizedLayout): """ + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) + """ + @classmethod + def quantize(cls, tensor, 
scale=None, fp8_dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if scale is None: + scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + if _CK_AVAILABLE and tensor.device.type == "cuda": + qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) + else: + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return qdata, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + return plain_tensor * scale + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + + +@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype) - - # Case 1: Both input and weight are FP8 - if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): - # Get plain tensors to avoid dispatch recursion - plain_input = input_tensor._raw_data - plain_weight = weight._raw_data - weight_t = plain_weight.t() # Keep as column-major for cuBLASLt + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + try: output = torch._scaled_mm( - plain_input, + plain_input.reshape(-1, input_shape[2]), weight_t, bias=bias, - scale_a=input_tensor._scale, - scale_b=weight._scale, + scale_a=scale_a, + scale_b=scale_b, out_dtype=out_dtype, ) - if isinstance(output, tuple): - output = output[0] - + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = input_tensor._scale * weight._scale - return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? 
Can't cuBLAS return it
+                output_scale = scale_a * scale_b
+                output_params = {
+                    'scale': output_scale,
+                    'orig_dtype': input_tensor._layout_params['orig_dtype']
+                }
+                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
             else:
                 return output
+
         except Exception as e:
-            logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+            raise RuntimeError(f"FP8 _scaled_mm failed: {e}")
 
-    # Case 2: Only weight is quantized
-    if isinstance(weight, QuantizedTensorFP8):
-        weight_dq = weight.dequantize()
-        input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor
-        return torch.nn.functional.linear(input_dq, weight_dq, bias)
-
-    # Case 3: Only input is quantized
-    elif isinstance(input_tensor, QuantizedTensorFP8):
-        input_dq = input_tensor.dequantize()
-        return torch.nn.functional.linear(input_dq, weight, bias)
-
-    # Case 4: Neither is quantized (shouldn't happen, but handle it)
-    else:
-        return torch.nn.functional.linear(input_tensor, weight, bias)
+    # Fallback: dequantize any quantized operand and use the standard linear path
+    if isinstance(weight, QuantizedTensor):
+        weight = weight.dequantize()
+    if isinstance(input_tensor, QuantizedTensor):
+        input_tensor = input_tensor.dequantize()
+    return torch.nn.functional.linear(input_tensor, weight, bias)
diff --git a/comfy/sd.py b/comfy/sd.py
index 28bee248d..b965e9842 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
 
@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     weight_dtype = comfy.utils.weight_dtype(sd)
     load_device = model_management.get_torch_device()
 
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
 
     if model_config is not None:
         new_sd = sd
@@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
 
 
 def load_diffusion_model(unet_path, model_options={}):
-    sd = comfy.utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options)
+    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

From efb35035f3531c6328a6eeff90fc88e873baa437 Mon Sep 17 00:00:00 2001
From: lspindler
Date: Mon, 27 Oct 2025 07:55:44 +0100
Subject: [PATCH 5/7] Remove CK reference and ensure correct compute dtype

---
 comfy/model_detection.py | 12 ++----------
 comfy/ops.py             |  4 +---
 comfy/quant_ops.py       | 26 ++++----------------------
 comfy/sd.py              |  5 ++++-
 4 files changed, 11 insertions(+), 36 deletions(-)

diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index ffb1885fd..335ccbd17 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -7,8 +7,7 @@ import logging
 import torch
 
 
-def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
-    # 1. 
Check for per-layer config in metadata +def detect_layer_quantization(metadata): quant_key = "_quantization_metadata" if metadata is not None and quant_key in metadata: quant_metadata = metadata.pop(quant_key) @@ -18,13 +17,6 @@ def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_mode return quant_metadata["layers"] else: raise ValueError(f"Invalid quantization metadata format") - - # 2. Check for legacy scaled_fp8 marker - scaled_fp8_key = f"{prefix}scaled_fp8" - if scaled_fp8_key in state_dict: - logging.debug("Detected legacy scaled_fp8 format, using legacy code path") - return None - return None @@ -724,7 +716,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.optimizations["fp8"] = True # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix) + layer_quant_config = detect_layer_quantization(metadata) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/ops.py b/comfy/ops.py index 8d11aeefc..5edd4daa2 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -561,11 +561,9 @@ class MixedPrecisionOps(disable_weight_init): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config - # MixedPrecisionOps._compute_dtype = compute_dtype # TODO - MixedPrecisionOps._compute_dtype = torch.bfloat16 + MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 3802da852..8d7f6480a 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -75,21 +75,6 @@ def _copy_layout_params(params): return new_params -try: - import comfy_kitchen as ck - ck.disable_backend("cutile") - _CK_AVAILABLE = True - logging.info("comfy_kitchen available for optimized quantization kernels") -except ImportError: - ck = None - _CK_AVAILABLE = False - logging.info("comfy_kitchen not available - using PyTorch fallbacks") -except Exception as e: - ck = None - _CK_AVAILABLE = False - logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") - - class QuantizedLayout: """ Base class for quantization layouts. 
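A note on extensibility, since this is the point of the registry design: a new quantized format only needs a QuantizedLayout subclass plus handlers registered for the ops it wants to intercept; detach/clone/device moves and the dequantize fallback come for free from the generic utilities. A minimal sketch, assuming comfy.quant_ops as patched above; Int8Layout and int8_linear are hypothetical names for illustration, not part of this series:

import torch
from comfy.quant_ops import QuantizedLayout, QuantizedTensor, register_layout_op

class Int8Layout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor, **kwargs):
        # Symmetric per-tensor int8: map the absolute max onto 127.
        scale = torch.amax(tensor.abs()) / 127.0
        qdata = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
        return qdata, {"scale": scale, "orig_dtype": tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale

@register_layout_op(torch.ops.aten.linear.default, Int8Layout)
def int8_linear(func, args, kwargs):
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    # Simplest possible handler: dequantize the weight and defer to F.linear;
    # a real kernel would consume the int8 data directly.
    return torch.nn.functional.linear(input_tensor, weight.dequantize(), bias)

w = QuantizedTensor.from_float(torch.randn(16, 8), Int8Layout)
y = torch.nn.functional.linear(torch.randn(4, 8), w)  # routed to int8_linear

Any op without a registered handler for the layout still works through the generic _dequant_and_fallback path, just without the quantized fast path.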
@@ -372,13 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - if _CK_AVAILABLE and tensor.device.type == "cuda": - qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - else: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) layout_params = { 'scale': scale, diff --git a/comfy/sd.py b/comfy/sd.py index b965e9842..6411bb27d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if hasattr(model_config, "layer_quant_config"): + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) if model_options.get("fp8_optimizations", False): From a7216e18e5cf40d0dcbadd2f4e4c03a0c3f38f49 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 08:41:23 +0100 Subject: [PATCH 6/7] Update unit tests --- comfy/ops.py | 2 +- comfy/quant_ops.py | 4 +- .../test_mixed_precision.py | 147 +++---- tests-unit/comfy_quant/test_quant_registry.py | 183 ++++++++ tests-unit/comfy_test/test_quant_detection.py | 262 ------------ tests-unit/comfy_test/test_quant_registry.py | 399 ------------------ 6 files changed, 235 insertions(+), 762 deletions(-) rename tests-unit/{comfy_test => comfy_quant}/test_mixed_precision.py (60%) create mode 100644 tests-unit/comfy_quant/test_quant_registry.py delete mode 100644 tests-unit/comfy_test/test_quant_detection.py delete mode 100644 tests-unit/comfy_test/test_quant_registry.py diff --git a/comfy/ops.py b/comfy/ops.py index 5edd4daa2..8af1e949d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -546,7 +546,7 @@ class MixedPrecisionOps(disable_weight_init): def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) - self._forward(input, weight, bias) + return self._forward(input, weight, bias) def forward(self, input, *args, **kwargs): run_every_op() diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8d7f6480a..96d2fa03f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -143,7 +143,7 @@ class QuantizedTensor(torch.Tensor): """ Tensor flattening protocol for proper device movement. """ - inner_tensors = ["_q_data"] + inner_tensors = ["_qdata"] ctx = { "layout_type": self._layout_type, } @@ -206,7 +206,7 @@ class QuantizedTensor(torch.Tensor): # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): - logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. 
kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py similarity index 60% rename from tests-unit/comfy_test/test_mixed_precision.py rename to tests-unit/comfy_quant/test_mixed_precision.py index cbfa2866d..e34552760 100644 --- a/tests-unit/comfy_test/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -1,8 +1,3 @@ -""" -End-to-end tests for mixed precision quantization. -Tests Phase 3: Mixed Precision Operations -""" - import unittest import torch import sys @@ -12,10 +7,10 @@ import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout class SimpleModel(torch.nn.Module): - """Simple model for testing mixed precision""" def __init__(self, operations=ops.disable_weight_init): super().__init__() self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) @@ -32,8 +27,7 @@ class SimpleModel(torch.nn.Module): class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps end-to-end""" - + def test_all_layers_standard(self): """Test that model with no quantization works normally""" # Configure no quantization @@ -67,48 +61,54 @@ class TestMixedPrecisionOps(unittest.TestCase): # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} # Disable for CPU testing + "format": "float8_e4m3fn", + "params": {} }, "layer3": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - # Layer 3: FP8 E5M2 + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), - "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - # Create model and load state dict + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) - # Verify handlers are set up correctly - self.assertIsNotNone(model.layer1.quant_handler) - self.assertIsNone(model.layer2.quant_handler) # No quantization - self.assertIsNotNone(model.layer3.quant_handler) + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be 
quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) # Verify scales were loaded - self.assertEqual(model.layer1.scale_weight.item(), 2.0) - self.assertEqual(model.layer3.scale_weight.item(), 1.5) + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -116,13 +116,13 @@ class TestMixedPrecisionOps(unittest.TestCase): self.assertEqual(output.shape, (5, 40)) - def test_state_dict_round_trip(self): - """Test saving and loading state dict preserves quantization""" + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -132,45 +132,35 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict1 = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model1 = SimpleModel(operations=ops.MixedPrecisionOps) - model1.load_state_dict(state_dict1) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) # Save state dict - state_dict2 = model1.state_dict() + state_dict2 = model.state_dict() - # Verify scale_weight is saved - self.assertIn("layer1.scale_weight", state_dict2) - self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - # Load into new model - model2 = SimpleModel(operations=ops.MixedPrecisionOps) - model2.load_state_dict(state_dict2) - - # Verify handler is set up - self.assertIsNotNone(model2.layer1.quant_handler) - self.assertEqual(model2.layer1.scale_weight.item(), 3.0) - - # Verify forward passes match - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output1 = model1(input_tensor) - output2 = model2(input_tensor) - - torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -180,7 +170,7 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict = { "layer1.weight": fp8_weight, "layer1.bias": 
torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), @@ -188,25 +178,24 @@ class TestMixedPrecisionOps(unittest.TestCase): } model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) - # LoRA delta must match weight shape (20, 10) + # This should trigger dequantization during forward pass def apply_lora(weight): - # Generate LoRA delta matching weight shape lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta model.layer1.weight_function.append(apply_lora) - # Forward pass should work with LoRA + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) def test_error_handling_unknown_format(self): - """Test that unknown formats fall back gracefully""" + """Test that unknown formats raise error""" # Configure with unknown format layer_quant_config = { "layer1": { @@ -226,48 +215,10 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - # Load should not crash, just log warning + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) - - # Handler should be None (fallback to standard) - self.assertIsNone(model.layer1.quant_handler) - - # Forward pass should still work - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) - self.assertEqual(output.shape, (5, 40)) - - -class TestPickOperationsWithMixedPrecision(unittest.TestCase): - """Test pick_operations with mixed precision config""" - - def test_pick_operations_with_layer_quant_config(self): - """Test that pick_operations returns MixedPrecisionOps when config present""" - from comfy import supported_models_base - - # Create model config with layer_quant_config - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = { - "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}} - } - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.MixedPrecisionOps) - self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) - - def test_pick_operations_without_layer_quant_config(self): - """Test that pick_operations falls back to standard when no config""" - from comfy import supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.disable_weight_init) - + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..263581417 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,183 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, 
os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = 
QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py deleted file mode 100644 index bb952a81b..000000000 --- a/tests-unit/comfy_test/test_quant_detection.py +++ /dev/null @@ -1,262 +0,0 @@ -""" -Integration tests for quantization detection. 
-Tests Phase 2: Detection & Integration -""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import model_detection - - -class TestNormalizeLayerName(unittest.TestCase): - """Test the normalize_layer_name helper function""" - - def test_strip_prefix_and_suffix(self): - """Test stripping prefix and suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_multiple_prefixes(self): - """Test with multiple known prefixes""" - known_prefixes = ["model.diffusion_model.", "model.model.", "net."] - - result1 = model_detection.normalize_layer_name( - "model.diffusion_model.block.attn.weight", - known_prefixes - ) - self.assertEqual(result1, "block.attn") - - result2 = model_detection.normalize_layer_name( - "model.model.encoder.layer.weight", - known_prefixes - ) - self.assertEqual(result2, "encoder.layer") - - result3 = model_detection.normalize_layer_name( - "net.transformer.blocks.0.weight", - known_prefixes - ) - self.assertEqual(result3, "transformer.blocks.0") - - def test_strip_scale_weight_suffix(self): - """Test stripping scale_weight suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.scale_weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_bias_suffix(self): - """Test stripping bias suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.bias", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_no_prefix_match(self): - """Test with no prefix match""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "other.model.layer1.weight", - known_prefixes - ) - # Should strip suffix but not prefix - self.assertEqual(result, "other.model.layer1") - - -class TestDetectLayerQuantization(unittest.TestCase): - """Test the detect_layer_quantization function""" - - def test_no_quantization(self): - """Test with no quantization markers""" - state_dict = { - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - "model.diffusion_model.layer2.weight": torch.randn(20, 30), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - self.assertIsNone(result) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 marker returns None""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None to trigger legacy path - self.assertIsNone(result) - - def test_metadata_format(self): - """Test with new metadata format""" - metadata = { - "format_version": "1.0", - "layers": { - "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "layer2": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - } - state_dict = { - 
"model.diffusion_model._quantization_metadata": metadata, - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer2", result) - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") - # Metadata should be popped from state_dict - self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) - - def test_mixed_precision_detection(self): - """Test detection of mixed precision via scale patterns""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # Layer 1: FP8 (has scale_weight) - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - # Layer 2: Standard (no scale_weight) - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - # Layer 3: FP8 (has scale_weight) - "model.diffusion_model.layer3.weight": fp8_weight3, - "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer3", result) - self.assertNotIn("layer2", result) # Layer 2 not quantized - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") - - def test_all_layers_quantized(self): - """Test that uniform quantization (all layers) returns None""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # All layers have scale_weight - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": fp8_weight2, - "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - # If all layers are quantized, it's not mixed precision - # Should return None to use legacy scaled_fp8_ops path - self.assertIsNone(result) - - def test_fp8_e5m2_detection(self): - """Test detection of FP8 E5M2 format""" - # Create FP8 E5M2 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) - state_dict = { - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled") - - def test_invalid_metadata(self): - """Test with invalid metadata format""" - state_dict = { - "model.diffusion_model._quantization_metadata": "invalid_string", - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, 
"model.diffusion_model.") - # Should return None on invalid metadata - self.assertIsNone(result) - - def test_different_prefix(self): - """Test with different model prefix (audio model)""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.model.layer1.weight": fp8_weight, - "model.model.layer1.scale_weight": torch.tensor(1.0), - "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - - -class TestPickOperationsIntegration(unittest.TestCase): - """Test pick_operations with model_config parameter""" - - def test_backward_compatibility(self): - """Test that pick_operations works without model_config (legacy)""" - from comfy import ops - - # Should work without model_config parameter - result = ops.pick_operations(None, None) - self.assertIsNotNone(result) - self.assertEqual(result, ops.disable_weight_init) - - def test_with_model_config_no_quant(self): - """Test with model_config but no quantization""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - self.assertIsNotNone(result) - # Should use standard path - self.assertEqual(result, ops.disable_weight_init) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 still works""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.scaled_fp8 = torch.float8_e4m3fn - - result = ops.pick_operations( - None, None, - scaled_fp8=torch.float8_e4m3fn, - model_config=model_config - ) - self.assertIsNotNone(result) - # Should return scaled_fp8_ops (the returned class is the inner class) - # Check that it's not the standard disable_weight_init - self.assertNotEqual(result, ops.disable_weight_init) - # Verify it has Linear class - self.assertTrue(hasattr(result, 'Linear')) - - -if __name__ == "__main__": - unittest.main() - diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py deleted file mode 100644 index 5c624b1db..000000000 --- a/tests-unit/comfy_test/test_quant_registry.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -Unit tests for tensor subclass quantization system. -Tests the new QuantizedTensorFP8 subclass and operation handlers. 
-""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import ops -from comfy import quant_ops - - -class TestQuantizedTensorFP8(unittest.TestCase): - """Test the QuantizedTensorFP8 tensor subclass""" - - def test_creation(self): - """Test creating a QuantizedTensorFP8""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - - self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._scale, scale) - self.assertEqual(qt._orig_dtype, torch.bfloat16) - - def test_dequantize(self): - """Test explicit dequantization""" - # Create a simple FP8 tensor - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) - dequantized = qt.dequantize() - - # Dequantized should be approximately ones * 3.0 - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_repr(self): - """Test string representation""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.5) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - repr_str = repr(qt) - - self.assertIn("QuantizedTensorFP8", repr_str) - self.assertIn("shape", repr_str) - self.assertIn("scale", repr_str) - - -class TestOperationRegistry(unittest.TestCase): - """Test the operation registry system""" - - def test_registry_basics(self): - """Test that operations are registered""" - registered_ops = quant_ops.list_registered_ops() - - # Check that key operations are registered - self.assertIn(torch.ops.aten.linear.default, registered_ops) - self.assertIn(torch.ops.aten.silu.default, registered_ops) - self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) - self.assertIn(torch.ops.aten.add.Tensor, registered_ops) - self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) - - def test_get_handler(self): - """Test getting a registered handler""" - handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) - self.assertIsNotNone(handler) - self.assertTrue(callable(handler)) - - def test_custom_registration(self): - """Test registering a custom operation""" - - # Define a custom handler - @quant_ops.register_quant_op(torch.ops.aten.relu.default) - def custom_relu_handler(func, args, kwargs): - return func(*args, **kwargs) - - # Verify registration - handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) - self.assertIsNotNone(handler) - self.assertEqual(handler, custom_relu_handler) - - -class TestLinearHandler(unittest.TestCase): - """Test the linear operation handler""" - - def test_linear_with_quantized_weight(self): - """Test F.linear with quantized weight""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - # Create input - input_tensor = torch.randn(16, 128, dtype=torch.float32) - - # Call linear (should 
trigger dispatch) - output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) - - # Verify output shape - self.assertEqual(output.shape, (16, 256)) - - # Verify it's approximately correct (allowing for FP8 quantization error) - # Note: FP8 has limited precision, so use very loose tolerance - expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) - # Just check that it's in the right ballpark (within 50% error on average) - mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") - - def test_linear_with_bias(self): - """Test F.linear with quantized weight and bias""" - weight_fp32 = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - input_tensor = torch.randn(8, 32, dtype=torch.float32) - bias = torch.randn(64, dtype=torch.float32) - - output = torch.nn.functional.linear(input_tensor, weight_q, bias) - - self.assertEqual(output.shape, (8, 64)) - - -class TestActivationHandlers(unittest.TestCase): - """Test activation function handlers""" - - def test_silu_with_quantized_input(self): - """Test SiLU with quantized input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply SiLU - output = torch.nn.functional.silu(input_q) - - # Should return a QuantizedTensorFP8 - self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) - - # Verify approximate correctness - expected = torch.nn.functional.silu(input_fp32) - output_dq = output.dequantize() - self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) - - def test_layernorm_dequantizes(self): - """Test that LayerNorm dequantizes input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply LayerNorm - weight = torch.ones(128) - bias = torch.zeros(128) - output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) - - # Should NOT be quantized (LayerNorm breaks quantization) - self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) - self.assertEqual(output.dtype, torch.float32) - - -class TestElementwiseHandlers(unittest.TestCase): - """Test element-wise operation handlers""" - - def test_add_mixed_tensors(self): - """Test addition with mixed quantized/non-quantized tensors""" - # Create quantized tensor - a_fp32 = torch.ones(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Non-quantized tensor - b = torch.ones(10, 20, dtype=torch.float32) * 2.0 - - # Add them - result = a_q + b - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_mul_quantized_tensors(self): - """Test multiplication of two quantized tensors""" - a_fp32 = torch.ones(10, 20) * 2.0 - scale_a = torch.tensor(1.0) - a_fp8 = (a_fp32 / 
scale_a).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) - - b_fp32 = torch.ones(10, 20) * 3.0 - scale_b = torch.tensor(1.0) - b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) - b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) - - result = a_q * b_q - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps integration""" - - def test_linear_layer_creation(self): - """Test that MixedPrecisionOps.Linear can be created""" - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) - - self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) - self.assertFalse(layer._quantization_initialized) - self.assertIsNone(layer.quant_format) - - def test_layer_quant_config_detection(self): - """Test that layer quantization config is detected during load""" - # Set up layer config - ops.MixedPrecisionOps._layer_quant_config = { - "test_layer": { - "format": "fp8_e4m3fn", - "params": {} - } - } - - # Create a state dict with quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - - state_dict = { - "model.diffusion_model.test_layer.weight": weight_fp8, - "model.diffusion_model.test_layer.scale_weight": scale, - } - - # Create layer and load - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) - layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) - - # Manually call _load_from_state_dict - layer._load_from_state_dict( - state_dict, - prefix="model.diffusion_model.test_layer.", - local_metadata={}, - strict=True, - missing_keys=[], - unexpected_keys=[], - error_msgs=[] - ) - - # Verify quantization was initialized - self.assertTrue(layer._quantization_initialized) - self.assertEqual(layer.quant_format, "fp8_e4m3fn") - self.assertIsNotNone(layer.quant_scale) - - # Verify weight is wrapped - self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) - - # Clean up - ops.MixedPrecisionOps._layer_quant_config = {} - - -class TestBackwardCompatibility(unittest.TestCase): - """Test backward compatibility with legacy systems""" - - def test_legacy_ops_classes_exist(self): - """Test that legacy ops classes still exist""" - 
self.assertTrue(hasattr(ops, 'disable_weight_init')) - self.assertTrue(hasattr(ops, 'manual_cast')) - self.assertTrue(hasattr(ops, 'fp8_ops')) - self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) - - def test_pick_operations_legacy_path(self): - """Test pick_operations returns correct class for legacy cases""" - # Test standard case - result = ops.pick_operations(torch.float32, torch.float32) - self.assertEqual(result, ops.disable_weight_init) - - # Test manual cast case - result = ops.pick_operations(torch.float32, torch.float16) - self.assertEqual(result, ops.manual_cast) - - -class TestFP8LinearUnification(unittest.TestCase): - """Test that fp8_linear now uses the unified tensor subclass infrastructure""" - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") - def test_fp8_linear_uses_tensor_subclass(self): - """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" - torch.manual_seed(42) - - # Create a mock Linear layer with FP8 weight - linear = ops.fp8_ops.Linear(4, 3, bias=True) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), - requires_grad=False - ) - linear.bias = torch.nn.Parameter( - torch.randn(3, dtype=torch.bfloat16), - requires_grad=False - ) - linear.scale_weight = torch.tensor(1.0) - linear.scale_input = None # No input scaling - - # Create input - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Call fp8_linear - should work without errors - try: - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNotNone(result) - self.assertEqual(result.shape, (2, 3)) - except Exception as e: - # On CPU or unsupported hardware, _scaled_mm might not be available - # but the function should still complete without syntax errors - pass - - def test_fp8_linear_maintains_signature(self): - """Verify fp8_linear maintains its original function signature""" - import inspect - sig = inspect.signature(ops.fp8_linear) - params = list(sig.parameters.keys()) - - # Should have 'self' and 'input' parameters - self.assertIn('self', params) - self.assertIn('input', params) - self.assertEqual(len(params), 2) - - def test_fp8_linear_returns_none_for_non_fp8(self): - """Verify fp8_linear returns None for non-FP8 weights""" - # Create a Linear layer with BF16 weight (not FP8) - linear = ops.disable_weight_init.Linear(4, 3, bias=False) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16), - requires_grad=False - ) - - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Should return None for non-FP8 weights - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNone(result) - - def test_fp8_ops_linear_uses_fp8_linear(self): - """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" - linear = ops.fp8_ops.Linear(4, 3, bias=False) - - # Verify the class has the forward_comfy_cast_weights method - self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) - - # The forward_comfy_cast_weights should attempt to call fp8_linear - # (we can't easily test this without mocking, but we verify structure) - import inspect - source = inspect.getsource(linear.forward_comfy_cast_weights) - self.assertIn('fp8_linear', source) - - -if __name__ == "__main__": - unittest.main() From 2a8b8264426c311eebee0ec9eb167f20f678c952 Mon Sep 17 00:00:00 2001 From: lspindler Date: Mon, 27 Oct 2025 08:52:50 +0100 Subject: [PATCH 7/7] ruff lint --- comfy/model_base.py | 2 +- comfy/model_detection.py | 2 +- comfy/ops.py | 20 ++--- comfy/quant_ops.py | 76 +++++++++---------- 
.../comfy_quant/test_mixed_precision.py | 68 ++++++++--------- tests-unit/comfy_quant/test_quant_registry.py | 67 ++++++++-------- 6 files changed, 117 insertions(+), 118 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7b4651f8e..f850cc402 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -326,7 +326,7 @@ class BaseModel(torch.nn.Module): if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - + # Save mixed precision metadata if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: metadata = { diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 335ccbd17..c4fc27742 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -16,7 +16,7 @@ def detect_layer_quantization(metadata): logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") return quant_metadata["layers"] else: - raise ValueError(f"Invalid quantization metadata format") + raise ValueError("Invalid quantization metadata format") return None diff --git a/comfy/ops.py b/comfy/ops.py index 8af1e949d..e2d76d7a9 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -325,7 +325,7 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): """ - Legacy FP8 linear function for backward compatibility. + Legacy FP8 linear function for backward compatibility. Uses QuantizedTensor subclass for dispatch. """ dtype = self.weight.dtype @@ -339,7 +339,7 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype - + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) @@ -354,14 +354,14 @@ def fp8_linear(self, input): scale_input = torch.ones((), device=input.device, dtype=torch.float32) else: scale_input = scale_input.to(input.device) - + # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! 
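+    # [Editorial note] Worth making explicit: this legacy path wraps on every
+    # call. The module still stores a plain FP8 weight plus scale_weight /
+    # scale_input attributes, and the wrapping below is what lets those old
+    # layers reuse the same __torch_dispatch__ handler that the new
+    # MixedPrecisionOps layers reach with weights already stored as
+    # QuantizedTensor. No module state changes; only the call path is unified.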
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -503,8 +503,8 @@ class MixedPrecisionOps(disable_weight_init): def reset_parameters(self): return None - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): device = self.factory_kwargs["device"] @@ -520,10 +520,10 @@ class MixedPrecisionOps(disable_weight_init): quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) if quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - + mixin = QUANT_FORMAT_MIXINS[quant_format] self.layout_type = mixin["layout_type"] - + layout_params = { 'scale': state_dict.pop(f"{prefix}weight_scale", None), 'orig_dtype': MixedPrecisionOps._compute_dtype @@ -558,7 +558,7 @@ class MixedPrecisionOps(disable_weight_init): not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) - + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: @@ -566,7 +566,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 96d2fa03f..aa1a231bd 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -31,7 +31,7 @@ def register_generic_util(torch_op): Decorator to register a generic utility that works for all layouts. Args: torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - + Example: @register_generic_util(torch.ops.aten.detach.default) def generic_detach(func, args, kwargs): @@ -78,10 +78,10 @@ def _copy_layout_params(params): class QuantizedLayout: """ Base class for quantization layouts. - + A layout encapsulates the format-specific logic for quantization/dequantization and provides a uniform interface for extracting raw tensors needed for computation. - + New quantization formats should subclass this and implement the required methods. 
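 
     Example -- an editorial sketch of a hypothetical subclass (symmetric INT8,
     not part of this patch), showing the expected quantize/dequantize contract:
 
         class SymmetricInt8Layout(QuantizedLayout):
             @classmethod
             def quantize(cls, tensor, scale):
                 # returns (qdata, layout_params), mirroring from_float()'s use
                 qdata = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
                 return qdata, {'scale': scale, 'orig_dtype': tensor.dtype}
 
             @staticmethod
             def dequantize(qdata, scale, orig_dtype):
                 return qdata.to(orig_dtype) * scale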
""" @classmethod @@ -90,8 +90,8 @@ class QuantizedLayout: @staticmethod def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError(f"TensorLayout must implement dequantize()") - + raise NotImplementedError("TensorLayout must implement dequantize()") + @classmethod def get_plain_tensors(cls, qtensor) -> torch.Tensor: raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") @@ -100,45 +100,45 @@ class QuantizedLayout: class QuantizedTensor(torch.Tensor): """ Universal quantized tensor that works with any layout. - + This tensor subclass uses a pluggable layout system to support multiple quantization formats (FP8, INT4, INT8, etc.) without code duplication. - + The layout_type determines format-specific behavior, while common operations (detach, clone, to) are handled generically. - + Attributes: _qdata: The quantized tensor data _layout_type: Layout class (e.g., TensorCoreFP8Layout) _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ - + @staticmethod def __new__(cls, qdata, layout_type, layout_params): """ Create a quantized tensor. - + Args: qdata: The quantized data tensor layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - + def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() self._layout_type = layout_type self._layout_params = layout_params - + def __repr__(self): layout_name = self._layout_type.__name__ param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - + @property def layout_type(self): return self._layout_type - + def __tensor_flatten__(self): """ Tensor flattening protocol for proper device movement. @@ -147,7 +147,7 @@ class QuantizedTensor(torch.Tensor): ctx = { "layout_type": self._layout_type, } - + tensor_params = {} non_tensor_params = {} for k, v in self._layout_params.items(): @@ -155,17 +155,17 @@ class QuantizedTensor(torch.Tensor): tensor_params[k] = v else: non_tensor_params[k] = v - + ctx["tensor_param_keys"] = list(tensor_params.keys()) ctx["non_tensor_params"] = non_tensor_params - + for k, v in tensor_params.items(): attr_name = f"_layout_param_{k}" object.__setattr__(self, attr_name, v) inner_tensors.append(attr_name) - + return inner_tensors, ctx - + @staticmethod def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): """ @@ -174,41 +174,41 @@ class QuantizedTensor(torch.Tensor): """ layout_type = ctx["layout_type"] layout_params = dict(ctx["non_tensor_params"]) - + for key in ctx["tensor_param_keys"]: attr_name = f"_layout_param_{key}" layout_params[key] = inner_tensors[attr_name] - + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) - + @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) - + def dequantize(self) -> torch.Tensor: return self._layout_type.dequantize(self._qdata, **self._layout_params) - + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - + # Step 1: Check generic utilities first (detach, clone, to, etc.) if func in _GENERIC_UTILS: return _GENERIC_UTILS[func](func, args, kwargs) - + # Step 2: Check layout-specific handlers (linear, matmul, etc.) 
layout_type = _get_layout_from_args(args) if layout_type and func in _LAYOUT_REGISTRY: handler = _LAYOUT_REGISTRY[func].get(layout_type) if handler: return handler(func, args, kwargs) - + # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) - + @classmethod def _dequant_and_fallback(cls, func, args, kwargs): def dequant_arg(arg): @@ -217,7 +217,7 @@ class QuantizedTensor(torch.Tensor): elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) return arg - + new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) @@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= f"QuantizedTensor: dtype conversion requested to {target_dtype}, " f"but not supported for quantized tensors. Ignoring dtype." ) - + if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " f"but not supported. Ignoring layout." ) - + # Handle device transfer current_device = qt._qdata.device if target_device is not None: @@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= target_device = torch.device(target_device) if isinstance(current_device, str): current_device = torch.device(current_device) - + if target_device != current_device: logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) @@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt - + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") return qt @@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs): def generic_copy_(func, args, kwargs): qt_dest = args[0] src = args[1] - + if isinstance(qt_dest, QuantizedTensor): if isinstance(src, QuantizedTensor): # Copy from another quantized tensor @@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - + out_dtype = kwargs.get("out_dtype") if out_dtype is None: out_dtype = input_tensor._layout_params['orig_dtype'] - + weight_t = plain_weight.t() tensor_2d = False @@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs): return QuantizedTensor(output, TensorCoreFP8Layout, output_params) else: return output - + except Exception as e: raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index e34552760..1102f9bd4 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -16,7 +16,7 @@ class SimpleModel(torch.nn.Module): self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) self.layer2 = operations.Linear(20, 30, 
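         # [Editorial sketch] These tests ultimately land in the
         # TensorCoreFP8Layout linear handler shown earlier, whose core is
         # torch._scaled_mm. A bare call looks roughly like this (private
         # PyTorch API; recent PyTorch and FP8-capable CUDA hardware assumed,
         # second operand transposed so it is column-major):
         #
         #   y = torch._scaled_mm(x_fp8, w_fp8.t(),
         #                        scale_a=s_x, scale_b=s_w,
         #                        out_dtype=torch.bfloat16)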
device="cpu", dtype=torch.bfloat16) self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) - + def forward(self, x): x = self.layer1(x) x = torch.nn.functional.relu(x) @@ -32,10 +32,10 @@ class TestMixedPrecisionOps(unittest.TestCase): """Test that model with no quantization works normally""" # Configure no quantization ops.MixedPrecisionOps._layer_quant_config = {} - + # Create model model = SimpleModel(operations=ops.MixedPrecisionOps) - + # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) @@ -43,19 +43,19 @@ class TestMixedPrecisionOps(unittest.TestCase): model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) - + # Initialize weight_function and bias_function for layer in [model.layer1, model.layer2, model.layer3]: layer.weight_function = [] layer.bias_function = [] - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) self.assertEqual(output.dtype, torch.bfloat16) - + def test_mixed_precision_load(self): """Test loading a mixed precision model from state dict""" # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard @@ -70,52 +70,52 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - + state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), - + # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) - + # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) - + # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) - + # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_state_dict_quantized_preserved(self): """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision @@ -126,7 +126,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } } 
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict1 = { @@ -138,22 +138,22 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict1, strict=False) - + # Save state dict state_dict2 = model.state_dict() - + # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - + # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) - + def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization @@ -164,7 +164,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { @@ -176,24 +176,24 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Add a weight function (simulating LoRA) # This should trigger dequantization during forward pass def apply_lora(weight): lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta - + model.layer1.weight_function.append(apply_lora) - + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_error_handling_unknown_format(self): """Test that unknown formats raise error""" # Configure with unknown format @@ -204,7 +204,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict state_dict = { "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), @@ -214,7 +214,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) with self.assertRaises(KeyError): diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 263581417..26e91a7ee 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -11,51 +11,51 @@ from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout class TestQuantizedTensor(unittest.TestCase): """Test the QuantizedTensor subclass with FP8 layout""" - + def test_creation(self): """Test creating a QuantizedTensor with TensorCoreFP8Layout""" fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(2.0) layout_params = {'scale': scale, 
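         # [Editorial note] On test_weight_function_compatibility above: a weight
         # patch (LoRA-style) cannot be applied inside the FP8 domain, so the
         # expected behaviour is a dequantize-patch-compute sequence, roughly:
         #
         #   w = layer.weight
         #   if isinstance(w, QuantizedTensor) and layer.weight_function:
         #       w = w.dequantize()            # back to the compute dtype
         #       for fn in layer.weight_function:
         #           w = fn(w)                 # e.g. apply_lora from the test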
'orig_dtype': torch.bfloat16} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) self.assertEqual(qt._layout_type, TensorCoreFP8Layout) - + def test_dequantize(self): """Test explicit dequantization""" fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) dequantized = qt.dequantize() - + self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) scale = torch.tensor(1.5) - + qt = QuantizedTensor.from_float( - float_tensor, - TensorCoreFP8Layout, + float_tensor, + TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt.shape, (64, 32)) - + # Verify dequantization gives approximately original values dequantized = qt.dequantize() mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() @@ -64,48 +64,48 @@ class TestQuantizedTensor(unittest.TestCase): class TestGenericUtilities(unittest.TestCase): """Test generic utility operations""" - + def test_detach(self): """Test detach operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Detach should return a new QuantizedTensor qt_detached = qt.detach() - + self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) - + def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Clone should return a new QuantizedTensor qt_cloned = qt.clone() - + self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) - + # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) - + def test_to_device(self): """Test device transfer""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') - + self.assertIsInstance(qt_cpu, QuantizedTensor) self.assertEqual(qt_cpu.device.type, 'cpu') self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') @@ -113,64 +113,63 @@ class TestGenericUtilities(unittest.TestCase): class TestTensorCoreFP8Layout(unittest.TestCase): """Test the TensorCoreFP8Layout implementation""" - + def test_quantize(self): 
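         # [Editorial sketch] detach/clone/to all route through the
         # _GENERIC_UTILS table. A representative handler, built from
         # register_generic_util and _copy_layout_params as defined in
         # quant_ops.py (the clone body itself is inferred):
         #
         #   @register_generic_util(torch.ops.aten.clone.default)
         #   def generic_clone(func, args, kwargs):
         #       qt = args[0]
         #       return QuantizedTensor(qt._qdata.clone(), qt._layout_type,
         #                              _copy_layout_params(qt._layout_params))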
"""Test quantization method""" float_tensor = torch.randn(32, 64, dtype=torch.float32) scale = torch.tensor(1.5) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) self.assertEqual(qdata.shape, float_tensor.shape) self.assertIn('scale', layout_params) self.assertIn('orig_dtype', layout_params) self.assertEqual(layout_params['orig_dtype'], torch.float32) - + def test_dequantize(self): """Test dequantization method""" float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 scale = torch.tensor(1.0) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - + # Should approximately match original self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) class TestFallbackMechanism(unittest.TestCase): """Test fallback for unsupported operations""" - + def test_unsupported_op_dequantizes(self): """Test that unsupported operations fall back to dequantization""" # Set seed for reproducibility torch.manual_seed(42) - + # Create quantized tensor a_fp32 = torch.randn(10, 20, dtype=torch.float32) scale = torch.tensor(1.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} a_q = QuantizedTensor.from_float( a_fp32, TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + # Call an operation that doesn't have a registered handler # For example, torch.abs result = torch.abs(a_q) - + # Should work via fallback (dequantize → abs → return) self.assertNotIsInstance(result, QuantizedTensor) expected = torch.abs(a_fp32)