Implement mixed precision operations with a registry design and metadata for the quant spec in the checkpoint.
parent: c170fd2db5
commit: 70acf79346
comfy/model_base.py

@@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
        if not unet_config.get("disable_unet_model_creation", False):
            if model_config.custom_operations is None:
                fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
            else:
                operations = model_config.custom_operations
            self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -332,6 +332,14 @@ class BaseModel(torch.nn.Module):
        if self.model_config.scaled_fp8 is not None:
            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

        # Save mixed precision metadata
        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
            metadata = {
                "format_version": "1.0",
                "layers": self.model_config.layer_quant_config
            }
            unet_state_dict["_quantization_metadata"] = metadata

        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
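For reference, a minimal sketch (not part of the diff) of the metadata entry this saving path produces; the layer name is invented for illustration:

import torch

unet_state_dict = {}
# Illustrative shape of the "_quantization_metadata" entry written above.
unet_state_dict["_quantization_metadata"] = {
    "format_version": "1.0",
    "layers": {
        "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {
            "format": "fp8_e4m3fn_scaled",
            "params": {"use_fp8_matmul": True},
        },
    },
}
# process_unet_state_dict_for_saving() then prefixes the key (typically to
# "model.diffusion_model._quantization_metadata"), which is the key that
# detect_layer_quantization() in comfy/model_detection.py pops on load.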
comfy/model_detection.py

@@ -6,6 +6,125 @@ import math
import logging
import torch


# ==============================================================================
# Quantization Detection Functions
# ==============================================================================

def normalize_layer_name(full_key, known_prefixes):
    """
    Strip model prefix and parameter suffix from a state dict key.

    Args:
        full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight")
        known_prefixes: List of known model prefixes to strip

    Returns:
        Normalized layer name (e.g., "layer1")
    """
    name = full_key

    # Strip model prefix
    for prefix in known_prefixes:
        if name.startswith(prefix):
            name = name[len(prefix):]
            break

    # Remove parameter suffix
    for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]:
        if name.endswith(suffix):
            name = name[:-len(suffix)]
            break

    return name


def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
    """
    Detect per-layer quantization configuration from state dict.

    Detection priority:
        1. Check for _quantization_metadata key (new format)
        2. Check for scaled_fp8 key (legacy format - return None)
        3. Check for per-layer scale_weight patterns (mixed detection)
        4. No quantization detected (return None)

    Args:
        state_dict: Model state dictionary
        prefix: Key prefix for model layers

    Returns:
        Dict mapping layer names to quantization configs, or None for legacy/no quantization.

        Example return value:
        {
            "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {
                "format": "fp8_e4m3fn_scaled",
                "params": {"use_fp8_matmul": True}
            },
            "middle_block.1.transformer_blocks.0.attn2.to_k": {
                "format": "fp8_e5m2_scaled",
                "params": {"use_fp8_matmul": True}
            }
        }
    """

    # 1. Check for new metadata format
    metadata_key = f"{prefix}_quantization_metadata"
    if metadata_key in state_dict:
        try:
            metadata = state_dict.pop(metadata_key)
            if isinstance(metadata, dict) and "layers" in metadata:
                logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})")
                return metadata["layers"]
            else:
                logging.warning("Invalid quantization metadata format, ignoring")
        except Exception as e:
            logging.error(f"Failed to parse quantization metadata: {e}")
        return None

    # 2. Check for legacy scaled_fp8 marker
    # If present, return None to use legacy code path
    scaled_fp8_key = f"{prefix}scaled_fp8"
    if scaled_fp8_key in state_dict:
        logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
        return None

    # 3. Check for per-layer scale patterns (mixed precision without metadata)
    # Look for layers that have scale_weight but not all layers have it
    known_prefixes = [prefix]
    layer_configs = {}
    layers_with_scale = set()
    layers_with_weight = set()

    for key in state_dict.keys():
        if key.startswith(prefix):
            if key.endswith(".scale_weight"):
                layer_name = normalize_layer_name(key, known_prefixes)
                layers_with_scale.add(layer_name)
                # Detect format based on weight dtype
                weight_key = f"{prefix}{layer_name}.weight"
                if weight_key in state_dict:
                    weight_dtype = state_dict[weight_key].dtype
                    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                        format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled"
                        layer_configs[layer_name] = {
                            "format": format_name,
                            "params": {"use_fp8_matmul": True}
                        }
            elif key.endswith(".weight") and not key.endswith(".scale_weight"):
                layer_name = normalize_layer_name(key, known_prefixes)
                layers_with_weight.add(layer_name)

    # If we found scale_weight on some but not all layers, it's mixed precision
    if layer_configs and len(layers_with_scale) < len(layers_with_weight):
        logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers")
        return layer_configs

    # 4. No quantization detected
    return None


def count_blocks(state_dict_keys, prefix_string):
    count = 0
    while True:
@@ -701,6 +820,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
        else:
            model_config.optimizations["fp8"] = True

    # Detect per-layer quantization (mixed precision)
    layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
    if layer_quant_config:
        model_config.layer_quant_config = layer_quant_config
        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")

    return model_config


def unet_prefix_from_state_dict(state_dict):
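To make the detection heuristics concrete, a toy sketch (not part of the diff) of calling detect_layer_quantization() on an invented two-layer state dict, where only one layer carries a scale_weight:

import torch
import comfy.model_detection

prefix = "model.diffusion_model."
sd = {
    f"{prefix}blocks.0.attn.to_q.weight": torch.zeros(4, 4, dtype=torch.float8_e4m3fn),
    f"{prefix}blocks.0.attn.to_q.scale_weight": torch.ones(()),
    f"{prefix}blocks.0.mlp.fc1.weight": torch.zeros(4, 4, dtype=torch.bfloat16),
}

config = comfy.model_detection.detect_layer_quantization(sd, prefix)
# -> {"blocks.0.attn.to_q": {"format": "fp8_e4m3fn_scaled",
#                            "params": {"use_fp8_matmul": True}}}
# A checkpoint that carries "model.diffusion_model._quantization_metadata"
# would short-circuit this heuristic and return its "layers" dict directly.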
comfy/ops.py

@@ -478,7 +478,457 @@ if CUBLAS_IS_AVAILABLE:
            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)

-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):

# ==============================================================================
# Quantization Format Registry System
# ==============================================================================

class QuantFormatHandler:
    """
    Base class for all quantization format handlers.

    A handler encapsulates the logic for a specific quantization format
    (e.g., FP8 scaled, MX formats) and manages the quantization
    parameters and forward pass for quantized layers.
    """

    def __init__(self, layer, **config):
        """
        Initialize handler for a specific layer.

        Args:
            layer: The nn.Module layer (Linear, Conv2d, etc.)
            **config: Format-specific configuration
        """
        self.layer = layer
        self.config = config

    def setup_parameters(self):
        """
        Initialize quantization parameters on the layer.
        Called during layer construction or load_state_dict.

        Subclasses should create parameters like scale_weight, scale_input, etc.
        and attach them to self.layer.
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()")

    def forward(self, *args, **kwargs):
        """
        Execute quantized forward pass.

        Signature matches the layer's expected forward pass.
        Handler accesses layer parameters via self.layer (weight, bias, etc.)

        Args:
            *args: Positional arguments matching layer forward signature
            **kwargs: Keyword arguments matching layer forward signature

        Returns:
            Layer output tensor

        Examples:
            Linear: forward(input)
            Conv2d: forward(input)
            GroupNorm: forward(input)
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement forward()")

    def load_state_dict(self, state_dict, prefix):
        """
        Load quantization parameters from state dict.

        Args:
            state_dict: State dictionary
            prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.")
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()")

    def state_dict(self, prefix):
        """
        Save quantization parameters to state dict.

        Args:
            prefix: Key prefix for this layer

        Returns:
            Dictionary of quantization parameters with full keys
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()")

    def convert_weight(self, weight, inplace=False):
        """
        Convert weight from quantized to full precision (dequantize).

        Args:
            weight: Quantized weight tensor
            inplace: Whether to modify in-place

        Returns:
            Dequantized weight tensor
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()")

    def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
        """
        Convert and set weight from full precision to quantized.

        Args:
            weight: Full precision weight tensor
            inplace_update: Whether to update layer weight in-place
            seed: Random seed for stochastic rounding
            return_weight: If True, return quantized weight without setting

        Returns:
            Quantized weight if return_weight=True, else None
        """
        raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()")

class QuantFormatRegistry:
    """
    Global registry for quantization formats.

    Formats are registered with a unique name and handler class.
    Custom formats can be registered by custom nodes.
    """

    _formats = {}

    @classmethod
    def register(cls, name, handler_class, **default_config):
        """
        Register a new quantization format.

        Args:
            name: Unique format identifier (e.g., "fp8_e4m3fn_scaled")
            handler_class: Handler class implementing QuantFormatHandler
            **default_config: Default configuration parameters

        Example:
            QuantFormatRegistry.register(
                "fp8_e4m3fn_scaled",
                handler_class=FP8ScaledHandler,
                base_dtype=torch.float8_e4m3fn,
                quantize_activation=False,
                use_fp8_matmul=True,
            )
        """
        if not issubclass(handler_class, QuantFormatHandler):
            raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}")

        cls._formats[name] = {
            "handler": handler_class,
            "config": default_config.copy()
        }
        logging.debug(f"Registered quantization format: {name}")

    @classmethod
    def get(cls, name, **override_config):
        """
        Get format info with optional config overrides.

        Args:
            name: Format identifier
            **override_config: Configuration overrides

        Returns:
            Dict with 'handler' (class) and 'config' (dict) keys

        Raises:
            ValueError: If format name not registered
        """
        if name not in cls._formats:
            available = ", ".join(cls._formats.keys()) if cls._formats else "none"
            raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}")

        format_info = cls._formats[name].copy()
        # Merge override_config into default config
        config = format_info["config"].copy()
        config.update(override_config)
        format_info["config"] = config
        return format_info

    @classmethod
    def list_formats(cls):
        """List all registered format names"""
        return list(cls._formats.keys())

    @classmethod
    def is_registered(cls, name):
        """Check if a format is registered"""
        return name in cls._formats

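As an illustration of the extension point above, a minimal sketch (not part of the diff) of a custom node registering its own format; the handler name, format name, and pass-through behavior are invented:

import torch
from comfy.ops import QuantFormatHandler, QuantFormatRegistry, cast_bias_weight

class IdentityHandler(QuantFormatHandler):
    """Do-nothing example handler: weights are already full precision."""
    def setup_parameters(self):
        pass  # no extra parameters to attach

    def forward(self, input):
        weight, bias = cast_bias_weight(self.layer, input)
        return torch.nn.functional.linear(input, weight, bias)

    def load_state_dict(self, state_dict, prefix):
        pass  # nothing extra to load

    def state_dict(self, prefix):
        return {}  # nothing extra to save

    def convert_weight(self, weight, inplace=False):
        return weight

    def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
        if return_weight:
            return weight
        self.layer.weight = torch.nn.Parameter(weight, requires_grad=False)

QuantFormatRegistry.register("identity_passthrough", handler_class=IdentityHandler)
assert QuantFormatRegistry.is_registered("identity_passthrough")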
class FP8ScaledHandler(QuantFormatHandler):
    """
    Handler for FP8 quantization with per-tensor scaling.

    Supports both weight-only and weight+activation quantization.
    Compatible with existing fp8_linear implementation.
    """

    def setup_parameters(self):
        """Initialize scale_weight and optionally scale_input"""
        device = self.layer.weight.device
        dtype = torch.float32

        # Always have scale_weight for FP8
        if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None:
            self.layer.scale_weight = torch.nn.Parameter(
                torch.ones((), device=device, dtype=dtype),
                requires_grad=False
            )

        # scale_input is optional (for activation quantization)
        if self.config.get("quantize_activation", False):
            if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None:
                self.layer.scale_input = torch.nn.Parameter(
                    torch.ones((), device=device, dtype=dtype),
                    requires_grad=False
                )
        else:
            self.layer.scale_input = None

    def forward(self, *args, **kwargs):
        """
        FP8 forward pass with optional activation quantization.
        Supports Linear layers (Conv2d in the future).
        """
        # Detect layer type and dispatch
        if isinstance(self.layer, torch.nn.Linear):
            return self._forward_linear(*args, **kwargs)
        else:
            raise NotImplementedError(
                f"FP8ScaledHandler not implemented for {type(self.layer).__name__}"
            )

    def _forward_linear(self, input):
        """FP8 forward for Linear layers"""
        # Try fast path with fp8_linear if enabled
        if self.config.get("use_fp8_matmul", False) and not self.layer.training:
            try:
                result = fp8_linear(self.layer, input)
                if result is not None:
                    return result
            except Exception as e:
                logging.debug(f"FP8 matmul failed, falling back to standard path: {e}")

        # Standard path: dequantize and compute
        weight, bias = cast_bias_weight(self.layer, input)

        # Dequantize weight
        scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
        weight = weight * scale

        # Apply weight functions (LoRA, etc.) - they see dequantized weights
        if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0:
            for f in self.layer.weight_function:
                weight = f(weight)

        if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0:
            for f in self.layer.bias_function:
                bias = f(bias) if bias is not None else None

        # Execute linear operation
        return torch.nn.functional.linear(input, weight, bias)

    def load_state_dict(self, state_dict, prefix):
        """Load scale parameters from state dict"""
        scale_weight_key = f"{prefix}scale_weight"
        if scale_weight_key in state_dict:
            self.layer.scale_weight.data.copy_(state_dict[scale_weight_key])

        scale_input_key = f"{prefix}scale_input"
        if scale_input_key in state_dict and self.layer.scale_input is not None:
            self.layer.scale_input.data.copy_(state_dict[scale_input_key])

    def state_dict(self, prefix):
        """Save scale parameters to state dict"""
        result = {f"{prefix}scale_weight": self.layer.scale_weight}
        if self.layer.scale_input is not None:
            result[f"{prefix}scale_input"] = self.layer.scale_input
        return result

    def convert_weight(self, weight, inplace=False):
        """Dequantize: multiply by scale"""
        scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
        if inplace:
            weight *= scale
            return weight
        return weight * scale

    def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
        """Quantize: divide by scale with stochastic rounding"""
        scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
        quantized = comfy.float.stochastic_rounding(
            weight / scale,
            self.layer.weight.dtype,
            seed=seed
        )

        if return_weight:
            return quantized

        if inplace_update:
            self.layer.weight.data.copy_(quantized)
        else:
            self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False)

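A small round-trip sketch (not part of the diff) showing the set_weight()/convert_weight() convention above on a toy Linear; the sizes and scale choice are arbitrary:

import torch
from comfy.ops import FP8ScaledHandler

linear = torch.nn.Linear(8, 8, bias=False)
linear.weight = torch.nn.Parameter(torch.zeros(8, 8, dtype=torch.float8_e4m3fn), requires_grad=False)

handler = FP8ScaledHandler(linear, quantize_activation=False, use_fp8_matmul=False)
handler.setup_parameters()  # attaches scale_weight, leaves scale_input = None

full_precision = torch.randn(8, 8)
linear.scale_weight.data.fill_(full_precision.abs().max().item() / 448.0)  # 448 = float8_e4m3fn max

handler.set_weight(full_precision, seed=0)                 # weight / scale, stochastically rounded to FP8
restored = handler.convert_weight(linear.weight.float())   # weight * scale, approximately the original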

# ==============================================================================
# Mixed Precision Operations
# ==============================================================================

class MixedPrecisionOps(disable_weight_init):
    """
    Operations class supporting per-layer quantization (mixed precision).

    This class enables different layers to use different quantization formats
    within the same model (e.g., some layers FP8, others BF16).

    Layer-specific quantization is configured via the _layer_quant_config class
    variable, which is set by pick_operations() when a model has mixed precision.
    """

    _layer_quant_config = {}  # Class variable set by pick_operations()

    class Linear(disable_weight_init.Linear):
        """Linear layer with optional per-layer quantization"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.quant_handler = None
            self._handler_initialized = False

        def reset_parameters(self):
            # Don't allocate weights - return None like disable_weight_init
            return None

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys, error_msgs):
            """
            Called by PyTorch during load_state_dict.
            This is where we initialize the handler since we now know the layer name.
            """
            if not self._handler_initialized:
                # Normalize layer name from prefix
                layer_name = prefix.rstrip('.')

                # Strip known model prefixes
                for model_prefix in ["model.diffusion_model.", "model.model.", "net."]:
                    if layer_name.startswith(model_prefix):
                        layer_name = layer_name[len(model_prefix):]
                        break

                # Check if this layer has quantization config
                # Access via parent class since _layer_quant_config is a class variable
                if layer_name in MixedPrecisionOps._layer_quant_config:
                    config = MixedPrecisionOps._layer_quant_config[layer_name]
                    try:
                        format_info = QuantFormatRegistry.get(
                            config["format"],
                            **config.get("params", {})
                        )

                        # Initialize handler
                        self.quant_handler = format_info["handler"](self, **format_info["config"])
                        self.quant_handler.setup_parameters()

                        # Let handler load its parameters (scale_weight, etc.)
                        self.quant_handler.load_state_dict(state_dict, prefix)

                        logging.debug(f"Initialized {config['format']} handler for layer {layer_name}")
                    except ValueError as e:
                        # Format not registered - fall back to standard precision
                        logging.warning(
                            f"Quantization format '{config['format']}' not registered for layer {layer_name}. "
                            f"Falling back to standard precision. Error: {e}"
                        )
                        self.quant_handler = None
                    except Exception as e:
                        logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}")
                        self.quant_handler = None

                self._handler_initialized = True

            # Call parent to load weight and bias
            super()._load_from_state_dict(
                state_dict, prefix, local_metadata,
                strict, missing_keys, unexpected_keys, error_msgs
            )

        def _save_to_state_dict(self, destination, prefix, keep_vars):
            """Save layer parameters including quantization metadata"""
            super()._save_to_state_dict(destination, prefix, keep_vars)

            # Save handler parameters (scale_weight, etc.)
            if self.quant_handler:
                handler_dict = self.quant_handler.state_dict(prefix)
                destination.update(handler_dict)

        def forward_comfy_cast_weights(self, input):
            """Forward pass with optional quantization"""
            if self.quant_handler:
                # Use handler for quantized forward
                return self.quant_handler.forward(input)
            else:
                # Standard path for non-quantized layers
                weight, bias = cast_bias_weight(self, input)
                return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            """Main forward pass"""
            run_every_op()
            # Same logic as disable_weight_init.Linear
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Create Conv layer (same as disable_weight_init)"""
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")

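To show how a layer acquires its handler at load time, a minimal sketch (not part of the diff); the layer name is invented, and _load_from_state_dict() is called directly here only for illustration (PyTorch normally calls it with the module's prefix during load_state_dict()):

import torch
from comfy.ops import MixedPrecisionOps, FP8ScaledHandler

MixedPrecisionOps._layer_quant_config = {
    "blocks.0.attn.to_q": {"format": "fp8_e4m3fn_scaled", "params": {"use_fp8_matmul": False}},
}

layer = MixedPrecisionOps.Linear(8, 8, bias=False, dtype=torch.float8_e4m3fn)
sd = {
    "blocks.0.attn.to_q.weight": torch.zeros(8, 8, dtype=torch.float8_e4m3fn),
    "blocks.0.attn.to_q.scale_weight": torch.ones(()),
}
layer._load_from_state_dict(sd, "blocks.0.attn.to_q.", {}, True, [], [], [])
assert isinstance(layer.quant_handler, FP8ScaledHandler)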
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
    """
    Select appropriate operations class for model.

    NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3).
    LEGACY: All other paths unchanged for backward compatibility.

    Args:
        weight_dtype: Weight storage dtype
        compute_dtype: Computation dtype
        load_device: Device for loading
        disable_fast_fp8: Disable fast FP8 paths
        fp8_optimizations: Enable FP8 optimizations
        scaled_fp8: Legacy FP8 dtype marker
        model_config: Model config object (optional, for mixed precision support)

    Returns:
        Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init)
    """
    # NEW: Check for mixed precision
    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
        return MixedPrecisionOps

    # LEGACY paths (unchanged)
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
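A quick sketch (not part of the diff) of the new routing: any config object exposing a non-empty layer_quant_config gets MixedPrecisionOps, everything else falls through to the legacy selection. The config class here is a stand-in:

import torch
import comfy.ops

class _DummyConfig:
    layer_quant_config = {"blocks.0.attn.to_q": {"format": "fp8_e4m3fn_scaled", "params": {}}}

ops = comfy.ops.pick_operations(torch.float8_e4m3fn, torch.bfloat16, model_config=_DummyConfig())
assert ops is comfy.ops.MixedPrecisionOps

ops = comfy.ops.pick_operations(torch.bfloat16, torch.bfloat16)
# no model_config / no layer_quant_config -> legacy paths pick fp8_ops,
# manual_cast, or disable_weight_init as before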
@@ -503,3 +953,35 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
        return disable_weight_init

    return manual_cast


# ==============================================================================
# Register built-in quantization formats
# ==============================================================================

# FP8 E4M3FN weight-only quantization
QuantFormatRegistry.register(
    "fp8_e4m3fn_scaled",
    handler_class=FP8ScaledHandler,
    base_dtype=torch.float8_e4m3fn,
    quantize_activation=False,
    use_fp8_matmul=True,
)

# FP8 E4M3FN weight+activation quantization
QuantFormatRegistry.register(
    "fp8_e4m3fn_scaled_dynamic",
    handler_class=FP8ScaledHandler,
    base_dtype=torch.float8_e4m3fn,
    quantize_activation=True,
    use_fp8_matmul=True,
)

# FP8 E5M2 weight-only quantization
QuantFormatRegistry.register(
    "fp8_e5m2_scaled",
    handler_class=FP8ScaledHandler,
    base_dtype=torch.float8_e5m2,
    quantize_activation=False,
    use_fp8_matmul=True,
)
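The built-ins registered above are available as soon as comfy.ops is imported; a short sketch (not part of the diff) of querying them, with a per-call override merged on top of the registered defaults:

import torch
from comfy.ops import QuantFormatRegistry, FP8ScaledHandler

assert QuantFormatRegistry.is_registered("fp8_e4m3fn_scaled")
print(QuantFormatRegistry.list_formats())
# ['fp8_e4m3fn_scaled', 'fp8_e4m3fn_scaled_dynamic', 'fp8_e5m2_scaled', ...]
# (plus anything registered by custom nodes)

info = QuantFormatRegistry.get("fp8_e5m2_scaled", use_fp8_matmul=False)
assert info["handler"] is FP8ScaledHandler
assert info["config"]["base_dtype"] == torch.float8_e5m2
assert info["config"]["use_fp8_matmul"] is False  # per-call override wins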
comfy/supported_models_base.py

@@ -50,6 +50,7 @@ class BASE:
    manual_cast_dtype = None
    custom_operations = None
    scaled_fp8 = None
    layer_quant_config = None  # Per-layer quantization configuration for mixed precision
    optimizations = {"fp8": False}

    @classmethod