Implement mixed precision operations with a registry design and metadata for the quantization spec in the checkpoint.

This commit is contained in:
lspindler 2025-10-16 16:07:43 +02:00
parent c170fd2db5
commit 70acf79346
4 changed files with 618 additions and 2 deletions

View File

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -333,6 +333,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:

View File

@@ -6,6 +6,125 @@ import math
import logging
import torch
# ==============================================================================
# Quantization Detection Functions
# ==============================================================================
def normalize_layer_name(full_key, known_prefixes):
"""
Strip model prefix and parameter suffix from a state dict key.
Args:
full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight")
known_prefixes: List of known model prefixes to strip
Returns:
Normalized layer name (e.g., "layer1")
"""
name = full_key
# Strip model prefix
for prefix in known_prefixes:
if name.startswith(prefix):
name = name[len(prefix):]
break
# Remove parameter suffix
for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]:
if name.endswith(suffix):
name = name[:-len(suffix)]
break
return name
def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
"""
Detect per-layer quantization configuration from state dict.
Detection priority:
1. Check for _quantization_metadata key (new format)
2. Check for scaled_fp8 key (legacy format - return None)
3. Check for per-layer scale_weight patterns (mixed detection)
4. No quantization detected (return None)
Args:
state_dict: Model state dictionary
prefix: Key prefix for model layers
Returns:
Dict mapping layer names to quantization configs, or None for legacy/no quantization.
Example return value:
{
"input_blocks.5.1.transformer_blocks.0.attn1.to_q": {
"format": "fp8_e4m3fn_scaled",
"params": {"use_fp8_matmul": True}
},
"middle_block.1.transformer_blocks.0.attn2.to_k": {
"format": "fp8_e5m2_scaled",
"params": {"use_fp8_matmul": True}
}
}
"""
# 1. Check for new metadata format
metadata_key = f"{prefix}_quantization_metadata"
if metadata_key in state_dict:
try:
metadata = state_dict.pop(metadata_key)
if isinstance(metadata, dict) and "layers" in metadata:
logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})")
return metadata["layers"]
else:
logging.warning(f"Invalid quantization metadata format, ignoring")
except Exception as e:
logging.error(f"Failed to parse quantization metadata: {e}")
return None
# 2. Check for legacy scaled_fp8 marker
# If present, return None to use legacy code path
scaled_fp8_key = f"{prefix}scaled_fp8"
if scaled_fp8_key in state_dict:
logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
return None
# 3. Check for per-layer scale patterns (mixed precision without metadata)
# Look for scale_weight entries that appear on some layers but not on all of them
known_prefixes = [prefix]
layer_configs = {}
layers_with_scale = set()
layers_with_weight = set()
for key in state_dict.keys():
if key.startswith(prefix):
if key.endswith(".scale_weight"):
layer_name = normalize_layer_name(key, known_prefixes)
layers_with_scale.add(layer_name)
# Detect format based on weight dtype
weight_key = f"{prefix}{layer_name}.weight"
if weight_key in state_dict:
weight_dtype = state_dict[weight_key].dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled"
layer_configs[layer_name] = {
"format": format_name,
"params": {"use_fp8_matmul": True}
}
elif key.endswith(".weight") and not key.endswith(".scale_weight"):
layer_name = normalize_layer_name(key, known_prefixes)
layers_with_weight.add(layer_name)
# If we found scale_weight on some but not all layers, it's mixed precision
if layer_configs and len(layers_with_scale) < len(layers_with_weight):
logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers")
return layer_configs
# 4. No quantization detected
return None
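A standalone sketch of the pattern-based detection path (not part of this diff; the layer names are illustrative, and the import assumes the function lands in comfy.model_detection alongside model_config_from_unet):
import torch
from comfy.model_detection import detect_layer_quantization  # assumed module path

prefix = "model.diffusion_model."
sd = {
    f"{prefix}input_blocks.1.weight": torch.empty(4, 4, dtype=torch.float8_e4m3fn),
    f"{prefix}input_blocks.1.scale_weight": torch.ones(()),
    f"{prefix}out.weight": torch.zeros(4, 4, dtype=torch.bfloat16),  # not quantized
}
config = detect_layer_quantization(sd, prefix)
# -> {"input_blocks.1": {"format": "fp8_e4m3fn_scaled", "params": {"use_fp8_matmul": True}}}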
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@@ -701,6 +820,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config
def unet_prefix_from_state_dict(state_dict):

View File

@@ -478,7 +478,457 @@ if CUBLAS_IS_AVAILABLE:
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
# ==============================================================================
# Quantization Format Registry System
# ==============================================================================
class QuantFormatHandler:
"""
Base class for all quantization format handlers.
A handler encapsulates the logic for a specific quantization format
(e.g., FP8 scaled, MX formats) and manages the quantization
parameters and forward pass for quantized layers.
"""
def __init__(self, layer, **config):
"""
Initialize handler for a specific layer.
Args:
layer: The nn.Module layer (Linear, Conv2d, etc.)
**config: Format-specific configuration
"""
self.layer = layer
self.config = config
def setup_parameters(self):
"""
Initialize quantization parameters on the layer.
Called during layer construction or load_state_dict.
Subclasses should create parameters like scale_weight, scale_input, etc.
and attach them to self.layer.
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()")
def forward(self, *args, **kwargs):
"""
Execute quantized forward pass.
Signature matches the layer's expected forward pass.
Handler accesses layer parameters via self.layer (weight, bias, etc.)
Args:
*args: Positional arguments matching layer forward signature
**kwargs: Keyword arguments matching layer forward signature
Returns:
Layer output tensor
Examples:
Linear: forward(input)
Conv2d: forward(input)
GroupNorm: forward(input)
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement forward()")
def load_state_dict(self, state_dict, prefix):
"""
Load quantization parameters from state dict.
Args:
state_dict: State dictionary
prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.")
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()")
def state_dict(self, prefix):
"""
Save quantization parameters to state dict.
Args:
prefix: Key prefix for this layer
Returns:
Dictionary of quantization parameters with full keys
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()")
def convert_weight(self, weight, inplace=False):
"""
Convert weight from quantized to full precision (dequantize).
Args:
weight: Quantized weight tensor
inplace: Whether to modify in-place
Returns:
Dequantized weight tensor
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()")
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
"""
Convert and set weight from full precision to quantized.
Args:
weight: Full precision weight tensor
inplace_update: Whether to update layer weight in-place
seed: Random seed for stochastic rounding
return_weight: If True, return quantized weight without setting
Returns:
Quantized weight if return_weight=True, else None
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()")
class QuantFormatRegistry:
"""
Global registry for quantization formats.
Formats are registered with a unique name and handler class.
Custom formats can be registered by custom nodes.
"""
_formats = {}
@classmethod
def register(cls, name, handler_class, **default_config):
"""
Register a new quantization format.
Args:
name: Unique format identifier (e.g., "fp8_e4m3fn_scaled")
handler_class: Handler class implementing QuantFormatHandler
**default_config: Default configuration parameters
Example:
QuantFormatRegistry.register(
"fp8_e4m3fn_scaled",
handler_class=FP8ScaledHandler,
base_dtype=torch.float8_e4m3fn,
quantize_activation=False,
use_fp8_matmul=True,
)
"""
if not issubclass(handler_class, QuantFormatHandler):
raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}")
cls._formats[name] = {
"handler": handler_class,
"config": default_config.copy()
}
logging.debug(f"Registered quantization format: {name}")
@classmethod
def get(cls, name, **override_config):
"""
Get format info with optional config overrides.
Args:
name: Format identifier
**override_config: Configuration overrides
Returns:
Dict with 'handler' (class) and 'config' (dict) keys
Raises:
ValueError: If format name not registered
"""
if name not in cls._formats:
available = ", ".join(cls._formats.keys()) if cls._formats else "none"
raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}")
format_info = cls._formats[name].copy()
# Merge override_config into default config
config = format_info["config"].copy()
config.update(override_config)
format_info["config"] = config
return format_info
@classmethod
def list_formats(cls):
"""List all registered format names"""
return list(cls._formats.keys())
@classmethod
def is_registered(cls, name):
"""Check if a format is registered"""
return name in cls._formats
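As a sketch of the extension point for custom nodes (BF16PassthroughHandler below is hypothetical and not part of this commit; it assumes the registry, the handler base class, and cast_bias_weight are importable from comfy.ops, where this diff adds them):
import torch
import comfy.ops

class BF16PassthroughHandler(comfy.ops.QuantFormatHandler):
    """Hypothetical handler: keeps weights in bf16 and runs a plain matmul."""
    def setup_parameters(self):
        pass  # no extra quantization parameters needed
    def forward(self, input):
        weight, bias = comfy.ops.cast_bias_weight(self.layer, input)
        return torch.nn.functional.linear(input, weight, bias)
    def load_state_dict(self, state_dict, prefix):
        pass
    def state_dict(self, prefix):
        return {}
    def convert_weight(self, weight, inplace=False):
        return weight
    def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
        q = weight.to(torch.bfloat16)
        if return_weight:
            return q
        if inplace_update:
            self.layer.weight.data.copy_(q)
        else:
            self.layer.weight = torch.nn.Parameter(q, requires_grad=False)

comfy.ops.QuantFormatRegistry.register(
    "bf16_passthrough",  # name referenced from layer_quant_config / checkpoint metadata
    handler_class=BF16PassthroughHandler,
    base_dtype=torch.bfloat16,
)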
class FP8ScaledHandler(QuantFormatHandler):
"""
Handler for FP8 quantization with per-tensor scaling.
Supports both weight-only and weight+activation quantization.
Compatible with existing fp8_linear implementation.
"""
def setup_parameters(self):
"""Initialize scale_weight and optionally scale_input"""
device = self.layer.weight.device
dtype = torch.float32
# Always have scale_weight for FP8
if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None:
self.layer.scale_weight = torch.nn.Parameter(
torch.ones((), device=device, dtype=dtype),
requires_grad=False
)
# scale_input is optional (for activation quantization)
if self.config.get("quantize_activation", False):
if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None:
self.layer.scale_input = torch.nn.Parameter(
torch.ones((), device=device, dtype=dtype),
requires_grad=False
)
else:
self.layer.scale_input = None
def forward(self, *args, **kwargs):
"""
FP8 forward pass with optional activation quantization.
Supports Linear layers (Conv2d in future).
"""
# Detect layer type and dispatch
if isinstance(self.layer, torch.nn.Linear):
return self._forward_linear(*args, **kwargs)
else:
raise NotImplementedError(
f"FP8ScaledHandler not implemented for {type(self.layer).__name__}"
)
def _forward_linear(self, input):
"""FP8 forward for Linear layers"""
# Try fast path with fp8_linear if enabled
if self.config.get("use_fp8_matmul", False) and not self.layer.training:
try:
result = fp8_linear(self.layer, input)
if result is not None:
return result
except Exception as e:
logging.debug(f"FP8 matmul failed, falling back to standard path: {e}")
# Standard path: dequantize and compute
weight, bias = cast_bias_weight(self.layer, input)
# Dequantize weight
scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
# Apply weight functions (LoRA, etc.) - they see dequantized weights
if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0:
weight = weight * scale
for f in self.layer.weight_function:
weight = f(weight)
else:
weight = weight * scale
if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0:
for f in self.layer.bias_function:
bias = f(bias) if bias is not None else None
# Execute linear operation
# The scale has already been folded into the weight above, so a single call covers
# both the plain and the weight_function paths
return torch.nn.functional.linear(input, weight, bias)
def load_state_dict(self, state_dict, prefix):
"""Load scale parameters from state dict"""
scale_weight_key = f"{prefix}scale_weight"
if scale_weight_key in state_dict:
self.layer.scale_weight.data.copy_(state_dict[scale_weight_key])
scale_input_key = f"{prefix}scale_input"
if scale_input_key in state_dict and self.layer.scale_input is not None:
self.layer.scale_input.data.copy_(state_dict[scale_input_key])
def state_dict(self, prefix):
"""Save scale parameters to state dict"""
result = {f"{prefix}scale_weight": self.layer.scale_weight}
if self.layer.scale_input is not None:
result[f"{prefix}scale_input"] = self.layer.scale_input
return result
def convert_weight(self, weight, inplace=False):
"""Dequantize: multiply by scale"""
scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
if inplace:
weight *= scale
return weight
return weight * scale
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
"""Quantize: divide by scale with stochastic rounding"""
scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
quantized = comfy.float.stochastic_rounding(
weight / scale,
self.layer.weight.dtype,
seed=seed
)
if return_weight:
return quantized
if inplace_update:
self.layer.weight.data.copy_(quantized)
else:
self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False)
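A standalone round-trip sketch for this handler on a plain nn.Linear (illustrative only; in the real forward path cast_bias_weight moves the weight to the compute dtype before the scale is applied):
import torch

layer = torch.nn.Linear(8, 8, bias=False)
layer.weight = torch.nn.Parameter(layer.weight.to(torch.float8_e4m3fn), requires_grad=False)

handler = FP8ScaledHandler(layer, quantize_activation=False, use_fp8_matmul=False)
handler.setup_parameters()                  # attaches scale_weight (initialized to 1.0)
layer.scale_weight.data.fill_(0.5)          # pretend calibration picked a scale of 0.5

full_precision = torch.randn(8, 8)
handler.set_weight(full_precision, seed=0)  # quantize: divide by scale, stochastic-round to fp8
restored = handler.convert_weight(layer.weight.float())  # dequantize: multiply back by scale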
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
class MixedPrecisionOps(disable_weight_init):
"""
Operations class supporting per-layer quantization (mixed precision).
This class enables different layers to use different quantization formats
within the same model (e.g., some layers FP8, others BF16).
Layer-specific quantization is configured via _layer_quant_config class variable,
which is set by pick_operations() when a model has mixed precision.
"""
_layer_quant_config = {} # Class variable set by pick_operations()
class Linear(disable_weight_init.Linear):
"""Linear layer with optional per-layer quantization"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.quant_handler = None
self._handler_initialized = False
def reset_parameters(self):
# Don't allocate weights - return None like disable_weight_init
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
"""
Called by PyTorch during load_state_dict.
This is where we initialize the handler since we now know the layer name.
"""
if not self._handler_initialized:
# Normalize layer name from prefix
layer_name = prefix.rstrip('.')
# Strip known model prefixes
for model_prefix in ["model.diffusion_model.", "model.model.", "net."]:
if layer_name.startswith(model_prefix):
layer_name = layer_name[len(model_prefix):]
break
# Check if this layer has quantization config
# Access via the enclosing MixedPrecisionOps class, since _layer_quant_config is a class variable on it
if layer_name in MixedPrecisionOps._layer_quant_config:
config = MixedPrecisionOps._layer_quant_config[layer_name]
try:
format_info = QuantFormatRegistry.get(
config["format"],
**config.get("params", {})
)
# Initialize handler
self.quant_handler = format_info["handler"](self, **format_info["config"])
self.quant_handler.setup_parameters()
# Let handler load its parameters (scale_weight, etc.)
self.quant_handler.load_state_dict(state_dict, prefix)
logging.debug(f"Initialized {config['format']} handler for layer {layer_name}")
except ValueError as e:
# Format not registered - fall back to standard precision
logging.warning(
f"Quantization format '{config['format']}' not registered for layer {layer_name}. "
f"Falling back to standard precision. Error: {e}"
)
self.quant_handler = None
except Exception as e:
logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}")
self.quant_handler = None
self._handler_initialized = True
# Call parent to load weight and bias
super()._load_from_state_dict(
state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs
)
def _save_to_state_dict(self, destination, prefix, keep_vars):
"""Save layer parameters including quantization metadata"""
super()._save_to_state_dict(destination, prefix, keep_vars)
# Save handler parameters (scale_weight, etc.)
if self.quant_handler:
handler_dict = self.quant_handler.state_dict(prefix)
destination.update(handler_dict)
def forward_comfy_cast_weights(self, input):
"""Forward pass with optional quantization"""
if self.quant_handler:
# Use handler for quantized forward
return self.quant_handler.forward(input)
else:
# Standard path for non-quantized layers
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
"""Main forward pass"""
run_every_op()
# Same logic as disable_weight_init.Linear
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
"""Create Conv layer (same as disable_weight_init)"""
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
"""
Select appropriate operations class for model.
NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3).
LEGACY: All other paths unchanged for backward compatibility.
Args:
weight_dtype: Weight storage dtype
compute_dtype: Computation dtype
load_device: Device for loading
disable_fast_fp8: Disable fast FP8 paths
fp8_optimizations: Enable FP8 optimizations
scaled_fp8: Legacy FP8 dtype marker
model_config: Model config object (optional, for mixed precision support)
Returns:
Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init)
"""
# NEW: Check for mixed precision
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
# LEGACY paths (unchanged)
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
@@ -503,3 +953,35 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
return disable_weight_init
return manual_cast
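A minimal sketch of the new routing (DummyConfig is a hypothetical stand-in for a model_config whose layer_quant_config was filled in by detect_layer_quantization):
class DummyConfig:
    layer_quant_config = {
        "input_blocks.1.attn.to_q": {
            "format": "fp8_e4m3fn_scaled",
            "params": {"use_fp8_matmul": True},
        },
    }

ops = pick_operations(torch.bfloat16, torch.bfloat16, model_config=DummyConfig())
assert ops is MixedPrecisionOps
assert MixedPrecisionOps._layer_quant_config is DummyConfig.layer_quant_config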
# ==============================================================================
# Register built-in quantization formats
# ==============================================================================
# FP8 E4M3FN weight-only quantization
QuantFormatRegistry.register(
"fp8_e4m3fn_scaled",
handler_class=FP8ScaledHandler,
base_dtype=torch.float8_e4m3fn,
quantize_activation=False,
use_fp8_matmul=True,
)
# FP8 E4M3FN weight+activation quantization
QuantFormatRegistry.register(
"fp8_e4m3fn_scaled_dynamic",
handler_class=FP8ScaledHandler,
base_dtype=torch.float8_e4m3fn,
quantize_activation=True,
use_fp8_matmul=True,
)
# FP8 E5M2 weight-only quantization
QuantFormatRegistry.register(
"fp8_e5m2_scaled",
handler_class=FP8ScaledHandler,
base_dtype=torch.float8_e5m2,
quantize_activation=False,
use_fp8_matmul=True,
)
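With only the built-ins registered at import time, a quick sanity check of the registry contents (a sketch):
assert QuantFormatRegistry.is_registered("fp8_e4m3fn_scaled")
print(QuantFormatRegistry.list_formats())
# ['fp8_e4m3fn_scaled', 'fp8_e4m3fn_scaled_dynamic', 'fp8_e5m2_scaled']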

View File

@@ -50,6 +50,7 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod