mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00

Updated design using Tensor Subclasses

This commit is contained in:
parent 70acf79346
commit 388294677e
@@ -107,10 +107,10 @@ def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
            if weight_key in state_dict:
                weight_dtype = state_dict[weight_key].dtype
                if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                    format_name = "fp8_e4m3fn_scaled" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2_scaled"
+                    format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2"
                    layer_configs[layer_name] = {
                        "format": format_name,
-                        "params": {"use_fp8_matmul": True}
+                        "params": {}
                    }
        elif key.endswith(".weight") and not key.endswith(".scale_weight"):
            layer_name = normalize_layer_name(key, known_prefixes)
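For reference, a hedged sketch of the per-layer entry this detection now emits after the change above; the layer name is made up, only the shape of the value follows the code in the hunk:

layer_configs = {
    "double_blocks.0.img_attn.qkv": {   # hypothetical layer name
        "format": "fp8_e4m3fn",         # or "fp8_e5m2", depending on the weight dtype
        "params": {},
    },
}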
514 comfy/ops.py
@@ -344,6 +344,21 @@ class manual_cast(disable_weight_init):


def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function - now uses tensor subclass infrastructure.
+
+    This function maintains backward compatibility with existing code while
+    routing all FP8 computation through the unified tensor subclass system.
+    All actual FP8 matmul logic is handled by the registered operation handlers
+    in quant_ops.py via __torch_dispatch__.
+
+    Args:
+        self: Linear layer with FP8 weight and scale parameters
+        input: Input tensor (any dtype)
+
+    Returns:
+        Output tensor or None if weight is not FP8
+    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None
@@ -355,10 +370,12 @@ def fp8_linear(self, input):

    input_shape = input.shape
    input_dtype = input.dtype

    if len(input.shape) == 3:
+        # Get weight and bias using standard casting
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
+
+        # Get scales (same as before)
        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
@@ -368,23 +385,31 @@ def fp8_linear(self, input):

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
        else:
            scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-        if bias is not None:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
-        else:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
-
-        if isinstance(o, tuple):
-            o = o[0]
+
+        # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
+        quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
+
+        # Handle input quantization and wrapping
+        if self.scale_input is None:
+            # Clamp input to FP8 range and quantize
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
+        else:
+            # Apply inverse scale and quantize
+            input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
+
+        # Wrap input in QuantizedTensorFP8
+        quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype)
+
+        # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
+        # This is the key unification: all FP8 computation goes through one path
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+
+        # Reshape output
        if tensor_2d:
            return o.reshape(input_shape[0], -1)

        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

    return None
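To make the unified dispatch path concrete, here is a minimal sketch of the call chain the rewritten fp8_linear relies on. It assumes the comfy.quant_ops module added by this commit, uses made-up shapes, and keeps activations in float32 so the CPU dequantize fallback stays dtype-consistent; on CUDA the FP8-times-FP8 case would instead reach torch._scaled_mm inside handle_linear_fp8.

import torch
from comfy.quant_ops import QuantizedTensorFP8

# Made-up weight/scale values; float32 activations keep the fallback path dtype-consistent.
w_fp8 = (torch.randn(32, 16) * 0.02).to(torch.float8_e4m3fn)
qw = QuantizedTensorFP8(w_fp8, torch.tensor(1.0), orig_dtype=torch.float32)

x = torch.randn(4, 16)
# One operand is a QuantizedTensorFP8, so __torch_dispatch__ intercepts the call and
# either a registered handler (handle_linear_fp8) or the dequantize fallback runs it.
y = torch.nn.functional.linear(x, qw)
assert y.shape == (4, 32)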
@@ -479,307 +504,8 @@ if CUBLAS_IS_AVAILABLE:
            return super().forward(*args, **kwargs)


-# ==============================================================================
-# Quantization Format Registry System
-# ==============================================================================
-
-
-class QuantFormatHandler:
-    """
-    Base class for all quantization format handlers.
-
-    A handler encapsulates the logic for a specific quantization format
-    (e.g., FP8 scaled, MX formats) and manages the quantization
-    parameters and forward pass for quantized layers.
-    """
-
-    def __init__(self, layer, **config):
-        """
-        Initialize handler for a specific layer.
-
-        Args:
-            layer: The nn.Module layer (Linear, Conv2d, etc.)
-            **config: Format-specific configuration
-        """
-        self.layer = layer
-        self.config = config
-
-    def setup_parameters(self):
-        """
-        Initialize quantization parameters on the layer.
-        Called during layer construction or load_state_dict.
-
-        Subclasses should create parameters like scale_weight, scale_input, etc.
-        and attach them to self.layer.
-        """
-        raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()")
-
-    def forward(self, *args, **kwargs):
-        """
-        Execute quantized forward pass.
-
-        Signature matches the layer's expected forward pass.
-        Handler accesses layer parameters via self.layer (weight, bias, etc.)
-
-        Args:
-            *args: Positional arguments matching layer forward signature
-            **kwargs: Keyword arguments matching layer forward signature
-
-        Returns:
-            Layer output tensor
-
-        Examples:
-            Linear: forward(input)
-            Conv2d: forward(input)
-            GroupNorm: forward(input)
-        """
-        raise NotImplementedError(f"{self.__class__.__name__} must implement forward()")
-
-    def load_state_dict(self, state_dict, prefix):
-        """
-        Load quantization parameters from state dict.
-
-        Args:
-            state_dict: State dictionary
-            prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.")
-        """
-        raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()")
-
-    def state_dict(self, prefix):
-        """
-        Save quantization parameters to state dict.
-
-        Args:
-            prefix: Key prefix for this layer
-
-        Returns:
-            Dictionary of quantization parameters with full keys
-        """
-        raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()")
-
-    def convert_weight(self, weight, inplace=False):
-        """
-        Convert weight from quantized to full precision (dequantize).
-
-        Args:
-            weight: Quantized weight tensor
-            inplace: Whether to modify in-place
-
-        Returns:
-            Dequantized weight tensor
-        """
-        raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()")
-
-    def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
-        """
-        Convert and set weight from full precision to quantized.
-
-        Args:
-            weight: Full precision weight tensor
-            inplace_update: Whether to update layer weight in-place
-            seed: Random seed for stochastic rounding
-            return_weight: If True, return quantized weight without setting
-
-        Returns:
-            Quantized weight if return_weight=True, else None
-        """
-        raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()")
-
-
-class QuantFormatRegistry:
-    """
-    Global registry for quantization formats.
-
-    Formats are registered with a unique name and handler class.
-    Custom formats can be registered by custom nodes.
-    """
-
-    _formats = {}
-
-    @classmethod
-    def register(cls, name, handler_class, **default_config):
-        """
-        Register a new quantization format.
-
-        Args:
-            name: Unique format identifier (e.g., "fp8_e4m3fn_scaled")
-            handler_class: Handler class implementing QuantFormatHandler
-            **default_config: Default configuration parameters
-
-        Example:
-            QuantFormatRegistry.register(
-                "fp8_e4m3fn_scaled",
-                handler_class=FP8ScaledHandler,
-                base_dtype=torch.float8_e4m3fn,
-                quantize_activation=False,
-                use_fp8_matmul=True,
-            )
-        """
-        if not issubclass(handler_class, QuantFormatHandler):
-            raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}")
-
-        cls._formats[name] = {
-            "handler": handler_class,
-            "config": default_config.copy()
-        }
-        logging.debug(f"Registered quantization format: {name}")
-
-    @classmethod
-    def get(cls, name, **override_config):
-        """
-        Get format info with optional config overrides.
-
-        Args:
-            name: Format identifier
-            **override_config: Configuration overrides
-
-        Returns:
-            Dict with 'handler' (class) and 'config' (dict) keys
-
-        Raises:
-            ValueError: If format name not registered
-        """
-        if name not in cls._formats:
-            available = ", ".join(cls._formats.keys()) if cls._formats else "none"
-            raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}")
-
-        format_info = cls._formats[name].copy()
-        # Merge override_config into default config
-        config = format_info["config"].copy()
-        config.update(override_config)
-        format_info["config"] = config
-        return format_info
-
-    @classmethod
-    def list_formats(cls):
-        """List all registered format names"""
-        return list(cls._formats.keys())
-
-    @classmethod
-    def is_registered(cls, name):
-        """Check if a format is registered"""
-        return name in cls._formats
-
-
-class FP8ScaledHandler(QuantFormatHandler):
-    """
-    Handler for FP8 quantization with per-tensor scaling.
-
-    Supports both weight-only and weight+activation quantization.
-    Compatible with existing fp8_linear implementation.
-    """
-
-    def setup_parameters(self):
-        """Initialize scale_weight and optionally scale_input"""
-        device = self.layer.weight.device
-        dtype = torch.float32
-
-        # Always have scale_weight for FP8
-        if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None:
-            self.layer.scale_weight = torch.nn.Parameter(
-                torch.ones((), device=device, dtype=dtype),
-                requires_grad=False
-            )
-
-        # scale_input is optional (for activation quantization)
-        if self.config.get("quantize_activation", False):
-            if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None:
-                self.layer.scale_input = torch.nn.Parameter(
-                    torch.ones((), device=device, dtype=dtype),
-                    requires_grad=False
-                )
-        else:
-            self.layer.scale_input = None
-
-    def forward(self, *args, **kwargs):
-        """
-        FP8 forward pass with optional activation quantization.
-        Supports Linear layers (Conv2d in future).
-        """
-        # Detect layer type and dispatch
-        if isinstance(self.layer, torch.nn.Linear):
-            return self._forward_linear(*args, **kwargs)
-        else:
-            raise NotImplementedError(
-                f"FP8ScaledHandler not implemented for {type(self.layer).__name__}"
-            )
-
-    def _forward_linear(self, input):
-        """FP8 forward for Linear layers"""
-        # Try fast path with fp8_linear if enabled
-        if self.config.get("use_fp8_matmul", False) and not self.layer.training:
-            try:
-                result = fp8_linear(self.layer, input)
-                if result is not None:
-                    return result
-            except Exception as e:
-                logging.debug(f"FP8 matmul failed, falling back to standard path: {e}")
-
-        # Standard path: dequantize and compute
-        weight, bias = cast_bias_weight(self.layer, input)
-
-        # Dequantize weight
-        scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
-
-        # Apply weight functions (LoRA, etc.) - they see dequantized weights
-        if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0:
-            weight = weight * scale
-            for f in self.layer.weight_function:
-                weight = f(weight)
-        else:
-            weight = weight * scale
-
-        if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0:
-            for f in self.layer.bias_function:
-                bias = f(bias) if bias is not None else None
-
-        # Execute linear operation
-        # Optimization: multiply by scale on smaller tensor
-        if weight.numel() < input.numel() and len(self.layer.weight_function) == 0:
-            return torch.nn.functional.linear(input, weight, bias)
-        else:
-            return torch.nn.functional.linear(input, weight, bias)
-
-    def load_state_dict(self, state_dict, prefix):
-        """Load scale parameters from state dict"""
-        scale_weight_key = f"{prefix}scale_weight"
-        if scale_weight_key in state_dict:
-            self.layer.scale_weight.data.copy_(state_dict[scale_weight_key])
-
-        scale_input_key = f"{prefix}scale_input"
-        if scale_input_key in state_dict and self.layer.scale_input is not None:
-            self.layer.scale_input.data.copy_(state_dict[scale_input_key])
-
-    def state_dict(self, prefix):
-        """Save scale parameters to state dict"""
-        result = {f"{prefix}scale_weight": self.layer.scale_weight}
-        if self.layer.scale_input is not None:
-            result[f"{prefix}scale_input"] = self.layer.scale_input
-        return result
-
-    def convert_weight(self, weight, inplace=False):
-        """Dequantize: multiply by scale"""
-        scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
-        if inplace:
-            weight *= scale
-            return weight
-        return weight * scale
-
-    def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False):
-        """Quantize: divide by scale with stochastic rounding"""
-        scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype)
-        quantized = comfy.float.stochastic_rounding(
-            weight / scale,
-            self.layer.weight.dtype,
-            seed=seed
-        )
-
-        if return_weight:
-            return quantized
-
-        if inplace_update:
-            self.layer.weight.data.copy_(quantized)
-        else:
-            self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False)
-
-
+# Import quantization operations from separate module
+from .quant_ops import QuantizedTensorFP8
# ==============================================================================
@@ -800,12 +526,13 @@ class MixedPrecisionOps(disable_weight_init):
    _layer_quant_config = {} # Class variable set by pick_operations()

    class Linear(disable_weight_init.Linear):
-        """Linear layer with optional per-layer quantization"""
+        """Linear layer with optional per-layer quantization using tensor subclasses"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
-            self.quant_handler = None
-            self._handler_initialized = False
+            self.quant_format = None
+            self.quant_scale = None
+            self._quantization_initialized = False

        def reset_parameters(self):
            # Don't allocate weights - return None like disable_weight_init
@@ -815,9 +542,16 @@ class MixedPrecisionOps(disable_weight_init):
                                  strict, missing_keys, unexpected_keys, error_msgs):
            """
            Called by PyTorch during load_state_dict.
-            This is where we initialize the handler since we now know the layer name.
+            Load weight and wrap in QuantizedTensorFP8 if this layer is quantized.
            """
-            if not self._handler_initialized:
+            # Call parent to load weight and bias first
+            super()._load_from_state_dict(
+                state_dict, prefix, local_metadata,
+                strict, missing_keys, unexpected_keys, error_msgs
+            )
+
+            # After weight is loaded, wrap it if this layer is quantized
+            if not self._quantization_initialized:
                # Normalize layer name from prefix
                layer_name = prefix.rstrip('.')

@@ -828,60 +562,78 @@ class MixedPrecisionOps(disable_weight_init):
                        break

                # Check if this layer has quantization config
-                # Access via parent class since _layer_quant_config is a class variable
                if layer_name in MixedPrecisionOps._layer_quant_config:
                    config = MixedPrecisionOps._layer_quant_config[layer_name]
-                    try:
-                        format_info = QuantFormatRegistry.get(
-                            config["format"],
-                            **config.get("params", {})
-                        )
-
-                        # Initialize handler
-                        self.quant_handler = format_info["handler"](self, **format_info["config"])
-                        self.quant_handler.setup_parameters()
-
-                        # Let handler load its parameters (scale_weight, etc.)
-                        self.quant_handler.load_state_dict(state_dict, prefix)
-
-                        logging.debug(f"Initialized {config['format']} handler for layer {layer_name}")
-                    except ValueError as e:
-                        # Format not registered - fall back to standard precision
-                        logging.warning(
-                            f"Quantization format '{config['format']}' not registered for layer {layer_name}. "
-                            f"Falling back to standard precision. Error: {e}"
-                        )
-                        self.quant_handler = None
-                    except Exception as e:
-                        logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}")
-                        self.quant_handler = None
-
-                self._handler_initialized = True
-
-            # Call parent to load weight and bias
-            super()._load_from_state_dict(
-                state_dict, prefix, local_metadata,
-                strict, missing_keys, unexpected_keys, error_msgs
-            )
+                    self.quant_format = config.get("format", "fp8_e4m3fn")
+
+                    # Load scale parameter
+                    scale_key = f"{prefix}scale_weight"
+                    if scale_key in state_dict:
+                        self.quant_scale = state_dict[scale_key]
+
+                        # Wrap weight in QuantizedTensorFP8
+                        if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                            try:
+                                # Determine original dtype (default to bfloat16)
+                                orig_dtype = torch.bfloat16
+
+                                # Wrap weight in quantized tensor subclass
+                                quantized_weight = QuantizedTensorFP8(
+                                    self.weight.data,
+                                    self.quant_scale,
+                                    orig_dtype=orig_dtype
+                                )
+
+                                # Replace weight parameter with wrapped version
+                                self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+
+                                logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})")
+                            except Exception as e:
+                                logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}")
+                                self.quant_format = None
+                                self.quant_scale = None
+                        else:
+                            logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization")
+                            self.quant_format = None
+                            self.quant_scale = None
+                    else:
+                        logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict")
+                        self.quant_format = None
+
+                self._quantization_initialized = True

        def _save_to_state_dict(self, destination, prefix, keep_vars):
-            """Save layer parameters including quantization metadata"""
-            super()._save_to_state_dict(destination, prefix, keep_vars)
-
-            # Save handler parameters (scale_weight, etc.)
-            if self.quant_handler:
-                handler_dict = self.quant_handler.state_dict(prefix)
-                destination.update(handler_dict)
+            """Save layer parameters including quantization scale"""
+            # First unwrap the weight if it's quantized
+            if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
+                # Temporarily unwrap to save the raw FP8 data
+                quantized_tensor = self.weight.data
+                raw_fp8_data = quantized_tensor._raw_data
+                original_weight = self.weight
+                self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)
+
+                # Call parent to save unwrapped weight
+                super()._save_to_state_dict(destination, prefix, keep_vars)
+
+                # Restore the wrapped weight
+                self.weight = original_weight
+
+                # Save the scale parameter
+                if self.quant_scale is not None:
+                    destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
+            else:
+                # Standard path for non-quantized weights
+                super()._save_to_state_dict(destination, prefix, keep_vars)

        def forward_comfy_cast_weights(self, input):
-            """Forward pass with optional quantization"""
-            if self.quant_handler:
-                # Use handler for quantized forward
-                return self.quant_handler.forward(input)
-            else:
-                # Standard path for non-quantized layers
-                weight, bias = cast_bias_weight(self, input)
-                return torch.nn.functional.linear(input, weight, bias)
+            """
+            Forward pass - tensor subclass handles dispatch automatically!
+            __torch_dispatch__ will route to registered handlers based on tensor types.
+            """
+            weight, bias = cast_bias_weight(self, input)
+
+            # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
+            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            """Main forward pass"""
@@ -953,35 +705,3 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
        return disable_weight_init

    return manual_cast
-
-
-# ==============================================================================
-# Register built-in quantization formats
-# ==============================================================================
-
-# FP8 E4M3FN weight-only quantization
-QuantFormatRegistry.register(
-    "fp8_e4m3fn_scaled",
-    handler_class=FP8ScaledHandler,
-    base_dtype=torch.float8_e4m3fn,
-    quantize_activation=False,
-    use_fp8_matmul=True,
-)
-
-# FP8 E4M3FN weight+activation quantization
-QuantFormatRegistry.register(
-    "fp8_e4m3fn_scaled_dynamic",
-    handler_class=FP8ScaledHandler,
-    base_dtype=torch.float8_e4m3fn,
-    quantize_activation=True,
-    use_fp8_matmul=True,
-)
-
-# FP8 E5M2 weight-only quantization
-QuantFormatRegistry.register(
-    "fp8_e5m2_scaled",
-    handler_class=FP8ScaledHandler,
-    base_dtype=torch.float8_e5m2,
-    quantize_activation=False,
-    use_fp8_matmul=True,
-)
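As a usage note, the following sketch mirrors the pick_operations unit test further down: a model config that carries a layer_quant_config selects MixedPrecisionOps and installs the per-layer formats as a class variable. The call signature and attribute name come from that test; the layer name and format string here are only illustrative.

from comfy import ops, supported_models_base

model_config = supported_models_base.BASE({})
model_config.layer_quant_config = {
    "layer1": {"format": "fp8_e4m3fn", "params": {}},  # illustrative entry
}

operations = ops.pick_operations(None, None, model_config=model_config)
assert operations is ops.MixedPrecisionOps
assert ops.MixedPrecisionOps._layer_quant_config == model_config.layer_quant_config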
346 comfy/quant_ops.py (new file)

@@ -0,0 +1,346 @@
import torch
import logging

# ==============================================================================
# Global Operation Registry
# ==============================================================================

# Global operation registry: torch operation → handler function
_QUANT_OP_REGISTRY = {}

def register_quant_op(torch_op):
    """
    Decorator to register an operation handler.

    Example:
        @register_quant_op(torch.ops.aten.linear.default)
        def handle_linear_fp8(func, args, kwargs):
            # Implementation
            ...
    """
    def decorator(handler_func):
        _QUANT_OP_REGISTRY[torch_op] = handler_func
        return handler_func
    return decorator


def get_quant_handler(torch_op):
    """Get registered handler for an operation"""
    return _QUANT_OP_REGISTRY.get(torch_op)


def list_registered_ops():
    """List all registered quantized operations"""
    return list(_QUANT_OP_REGISTRY.keys())


# ==============================================================================
# comfy_kitchen Integration
# ==============================================================================

try:
    import comfy_kitchen as ck
    ck.disable_backend("cutile")
    _CK_AVAILABLE = True
    logging.info("comfy_kitchen available for optimized quantization kernels")
except ImportError:
    ck = None
    _CK_AVAILABLE = False
    logging.info("comfy_kitchen not available - using PyTorch fallbacks")
except Exception as e:
    ck = None
    _CK_AVAILABLE = False
    logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")


# ==============================================================================
# Quantized Tensor Subclass
# ==============================================================================

class QuantizedTensorFP8(torch.Tensor):
    """
    Tensor subclass for FP8 quantized data.
    Automatically handles operations via __torch_dispatch__.
    """

    @staticmethod
    def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16):
        """
        Create a quantized FP8 tensor.

        Args:
            tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2)
            scale: Scale factor for dequantization (scalar tensor)
            orig_dtype: Original dtype before quantization
        """
        return torch.Tensor._make_subclass(cls, tensor, require_grad=False)

    def __init__(self, tensor, scale, orig_dtype=torch.bfloat16):
        self._scale = scale
        self._orig_dtype = orig_dtype
        # Store a reference to prevent infinite recursion in dequantize
        self._raw_data = tensor

    def __repr__(self):
        return (f"QuantizedTensorFP8(shape={self.shape}, "
                f"scale={self._scale:.4f}, dtype={self._orig_dtype})")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """
        Intercept ALL torch operations.
        Routes to registered handlers or falls back to dequantization.
        """
        kwargs = kwargs or {}

        # Special case: skip dispatch for internal tensor operations
        # that are used for unwrapping (to avoid recursion)
        if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]:
            # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach
            if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8):
                # Special handling for detach - return a new QuantizedTensorFP8
                qt = args[0]
                detached_data = qt._raw_data.detach()
                return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype)

            # For other ops, just unwrap
            def unwrap(arg):
                if isinstance(arg, QuantizedTensorFP8):
                    return arg._raw_data
                return arg
            new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args)
            return func(*new_args, **kwargs)

        # Look up registered handler for this operation
        handler = _QUANT_OP_REGISTRY.get(func)
        if handler:
            return handler(func, args, kwargs)

        # No handler - dequantize and use standard path
        return cls._dequant_and_fallback(func, args, kwargs)

    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        """Fallback: dequantize all quantized tensors"""
        def dequant_arg(arg):
            if isinstance(arg, QuantizedTensorFP8):
                return arg.dequantize()
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
            return arg

        new_args = dequant_arg(args)
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)

    def dequantize(self) -> torch.Tensor:
        """Explicit dequantization"""
        # Use the raw data and convert directly
        # Call aten ops directly to minimize dispatch interference
        plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
        # Multiply by scale
        return plain_tensor * self._scale

    def detach(self):
        """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
        # Detach the raw data and create a new QuantizedTensorFP8
        detached_data = self._raw_data.detach()
        return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)


# ==============================================================================
# Operation Handlers for Quantized Tensors
# ==============================================================================

@register_quant_op(torch.ops.aten.linear.default)
def handle_linear_fp8(func, args, kwargs):
    """
    Handle F.linear() with quantized inputs.

    Supports:
    - QuantizedTensorFP8 input + QuantizedTensorFP8 weight
    - QuantizedTensorFP8 input + regular weight
    - Regular input + QuantizedTensorFP8 weight
    """
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    # Case 1: Both input and weight are FP8
    if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
        # Use _scaled_mm for FP8×FP8 matmul
        # Get plain tensors to avoid dispatch recursion
        plain_input = input_tensor._raw_data
        plain_weight = weight._raw_data
        weight_t = plain_weight.t().contiguous()

        try:
            if bias is not None:
                output = torch._scaled_mm(
                    plain_input,
                    weight_t,
                    out_dtype=input_tensor._orig_dtype,
                    bias=bias,
                    scale_a=input_tensor._scale,
                    scale_b=weight._scale
                )
            else:
                output = torch._scaled_mm(
                    plain_input,
                    weight_t,
                    out_dtype=input_tensor._orig_dtype,
                    scale_a=input_tensor._scale,
                    scale_b=weight._scale
                )

            if isinstance(output, tuple):
                output = output[0]

            # Check if output is FP8 (some architectures support this)
            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                # Keep quantized!
                output_scale = input_tensor._scale * weight._scale
                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)
            else:
                return output
        except Exception as e:
            logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
            # Fall through to dequantization path

    # Case 2: Only weight is quantized
    if isinstance(weight, QuantizedTensorFP8):
        weight_dq = weight.dequantize()
        input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor
        return torch.nn.functional.linear(input_dq, weight_dq, bias)

    # Case 3: Only input is quantized
    elif isinstance(input_tensor, QuantizedTensorFP8):
        input_dq = input_tensor.dequantize()
        return torch.nn.functional.linear(input_dq, weight, bias)

    # Case 4: Neither is quantized (shouldn't happen, but handle it)
    else:
        return torch.nn.functional.linear(input_tensor, weight, bias)


@register_quant_op(torch.ops.aten.silu.default)
def handle_silu_fp8(func, args, kwargs):
    """
    SiLU can be computed approximately on FP8.
    Keeps activations quantized for next layer.
    """
    input_q = args[0]

    if not isinstance(input_q, QuantizedTensorFP8):
        # Not quantized, use standard path
        return torch.nn.functional.silu(input_q)

    # Compute SiLU while keeping quantized
    # SiLU(x) = x * sigmoid(x)

    # Get plain tensor to avoid dispatch recursion
    plain_tensor = input_q._raw_data

    # Upcast to FP16 for sigmoid stability
    x_fp16 = plain_tensor.to(torch.float16)
    sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale)
    result_fp16 = x_fp16 * sigmoid_fp16

    # Convert back to FP8
    result_fp8 = result_fp16.to(plain_tensor.dtype)

    # Return quantized (scale approximately preserved)
    return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype)


@register_quant_op(torch.ops.aten.layer_norm.default)
def handle_layernorm_fp8(func, args, kwargs):
    """
    LayerNorm requires high precision.
    Dequantizes input and returns standard tensor.
    """
    input_q = args[0]
    normalized_shape = args[1]
    weight = args[2] if len(args) > 2 else None
    bias = args[3] if len(args) > 3 else None
    eps = args[4] if len(args) > 4 else 1e-5

    # Dequantize if needed
    if isinstance(input_q, QuantizedTensorFP8):
        x = input_q.dequantize()
    else:
        x = input_q

    # Standard LayerNorm
    result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)

    # Return dequantized (next layer will quantize if needed)
    return result


@register_quant_op(torch.ops.aten.group_norm.default)
def handle_groupnorm_fp8(func, args, kwargs):
    """
    GroupNorm requires high precision.
    Dequantizes input and returns standard tensor.
    """
    input_q = args[0]
    num_groups = args[1]
    weight = args[2] if len(args) > 2 else None
    bias = args[3] if len(args) > 3 else None
    eps = args[4] if len(args) > 4 else 1e-5

    # Dequantize if needed
    if isinstance(input_q, QuantizedTensorFP8):
        x = input_q.dequantize()
    else:
        x = input_q

    # Standard GroupNorm
    result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)

    # Return dequantized
    return result


@register_quant_op(torch.ops.aten.add.Tensor)
def handle_add_fp8(func, args, kwargs):
    """
    Handle addition with mixed quantized/non-quantized tensors.
    """
    a = args[0]
    b = args[1]

    # If both are quantized, dequantize both
    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
        return a.dequantize() + b.dequantize()
    # If only one is quantized, dequantize it
    elif isinstance(a, QuantizedTensorFP8):
        return a.dequantize() + b
    elif isinstance(b, QuantizedTensorFP8):
        return a + b.dequantize()
    # Neither is quantized
    else:
        return a + b


@register_quant_op(torch.ops.aten.mul.Tensor)
def handle_mul_fp8(func, args, kwargs):
    """
    Handle multiplication with mixed quantized/non-quantized tensors.
    """
    a = args[0]
    b = args[1]

    # If both are quantized, dequantize both
    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
        return a.dequantize() * b.dequantize()
    # If only one is quantized, dequantize it
    elif isinstance(a, QuantizedTensorFP8):
        return a.dequantize() * b
    elif isinstance(b, QuantizedTensorFP8):
        return a * b.dequantize()
    # Neither is quantized
    else:
        return a * b
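Because the registry is keyed by aten overloads, a custom node can extend it without touching this file. Below is a hedged sketch (not part of this commit) that registers a handler for aten.gelu.default using the same (func, args, kwargs) convention; the dequantize-then-compute body mirrors the LayerNorm/GroupNorm handlers above, and without it the generic fallback would behave the same way.

import torch
from comfy.quant_ops import register_quant_op, QuantizedTensorFP8

@register_quant_op(torch.ops.aten.gelu.default)
def handle_gelu_fp8(func, args, kwargs):
    x = args[0]
    if isinstance(x, QuantizedTensorFP8):
        x = x.dequantize()  # compute GELU in the original precision
    return torch.nn.functional.gelu(x)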
274 tests-unit/comfy_test/test_mixed_precision.py (new file)

@@ -0,0 +1,274 @@
"""
|
||||||
|
End-to-end tests for mixed precision quantization.
|
||||||
|
Tests Phase 3: Mixed Precision Operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add comfy to path
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
|
from comfy import ops
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleModel(torch.nn.Module):
|
||||||
|
"""Simple model for testing mixed precision"""
|
||||||
|
def __init__(self, operations=ops.disable_weight_init):
|
||||||
|
super().__init__()
|
||||||
|
self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
|
||||||
|
self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
|
||||||
|
self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = torch.nn.functional.relu(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = torch.nn.functional.relu(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TestMixedPrecisionOps(unittest.TestCase):
|
||||||
|
"""Test MixedPrecisionOps end-to-end"""
|
||||||
|
|
||||||
|
def test_all_layers_standard(self):
|
||||||
|
"""Test that model with no quantization works normally"""
|
||||||
|
# Configure no quantization
|
||||||
|
ops.MixedPrecisionOps._layer_quant_config = {}
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||||
|
|
||||||
|
# Initialize weights manually
|
||||||
|
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
|
||||||
|
model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
|
||||||
|
model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
|
||||||
|
model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
|
||||||
|
model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
|
||||||
|
model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))
|
||||||
|
|
||||||
|
# Initialize weight_function and bias_function
|
||||||
|
for layer in [model.layer1, model.layer2, model.layer3]:
|
||||||
|
layer.weight_function = []
|
||||||
|
layer.bias_function = []
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||||
|
output = model(input_tensor)
|
||||||
|
|
||||||
|
self.assertEqual(output.shape, (5, 40))
|
||||||
|
self.assertEqual(output.dtype, torch.bfloat16)
|
||||||
|
|
||||||
|
def test_mixed_precision_load(self):
|
||||||
|
"""Test loading a mixed precision model from state dict"""
|
||||||
|
# Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
|
||||||
|
layer_quant_config = {
|
||||||
|
"layer1": {
|
||||||
|
"format": "fp8_e4m3fn_scaled",
|
||||||
|
"params": {"use_fp8_matmul": False} # Disable for CPU testing
|
||||||
|
},
|
||||||
|
"layer3": {
|
||||||
|
"format": "fp8_e5m2_scaled",
|
||||||
|
"params": {"use_fp8_matmul": False}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||||
|
|
||||||
|
# Create state dict with mixed precision
|
||||||
|
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
|
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2)
|
||||||
|
|
||||||
|
state_dict = {
|
||||||
|
# Layer 1: FP8 E4M3FN
|
||||||
|
"layer1.weight": fp8_weight1,
|
||||||
|
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||||
|
"layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
|
||||||
|
|
||||||
|
# Layer 2: Standard BF16
|
||||||
|
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||||
|
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||||
|
|
||||||
|
# Layer 3: FP8 E5M2
|
||||||
|
"layer3.weight": fp8_weight3,
|
||||||
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
|
"layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create model and load state dict
|
||||||
|
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# Verify handlers are set up correctly
|
||||||
|
self.assertIsNotNone(model.layer1.quant_handler)
|
||||||
|
self.assertIsNone(model.layer2.quant_handler) # No quantization
|
||||||
|
self.assertIsNotNone(model.layer3.quant_handler)
|
||||||
|
|
||||||
|
# Verify scales were loaded
|
||||||
|
self.assertEqual(model.layer1.scale_weight.item(), 2.0)
|
||||||
|
self.assertEqual(model.layer3.scale_weight.item(), 1.5)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||||
|
output = model(input_tensor)
|
||||||
|
|
||||||
|
self.assertEqual(output.shape, (5, 40))
|
||||||
|
|
||||||
|
def test_state_dict_round_trip(self):
|
||||||
|
"""Test saving and loading state dict preserves quantization"""
|
||||||
|
# Configure mixed precision
|
||||||
|
layer_quant_config = {
|
||||||
|
"layer1": {
|
||||||
|
"format": "fp8_e4m3fn_scaled",
|
||||||
|
"params": {"use_fp8_matmul": False}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||||
|
|
||||||
|
# Create and load model
|
||||||
|
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
|
state_dict1 = {
|
||||||
|
"layer1.weight": fp8_weight,
|
||||||
|
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||||
|
"layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32),
|
||||||
|
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||||
|
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||||
|
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||||
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
|
}
|
||||||
|
|
||||||
|
model1 = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||||
|
model1.load_state_dict(state_dict1)
|
||||||
|
|
||||||
|
# Save state dict
|
||||||
|
state_dict2 = model1.state_dict()
|
||||||
|
|
||||||
|
# Verify scale_weight is saved
|
||||||
|
self.assertIn("layer1.scale_weight", state_dict2)
|
||||||
|
self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0)
|
||||||
|
|
||||||
|
# Load into new model
|
||||||
|
model2 = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||||
|
model2.load_state_dict(state_dict2)
|
||||||
|
|
||||||
|
# Verify handler is set up
|
||||||
|
self.assertIsNotNone(model2.layer1.quant_handler)
|
||||||
|
self.assertEqual(model2.layer1.scale_weight.item(), 3.0)
|
||||||
|
|
||||||
|
# Verify forward passes match
|
||||||
|
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||||
|
output1 = model1(input_tensor)
|
||||||
|
output2 = model2(input_tensor)
|
||||||
|
|
||||||
|
torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
|
def test_weight_function_compatibility(self):
|
||||||
|
"""Test that weight_function (LoRA) works with quantized layers"""
|
||||||
|
# Configure FP8 quantization
|
||||||
|
layer_quant_config = {
|
||||||
|
"layer1": {
|
||||||
|
"format": "fp8_e4m3fn_scaled",
|
||||||
|
"params": {"use_fp8_matmul": False}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||||
|
|
||||||
|
# Create and load model
|
||||||
|
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
|
state_dict = {
|
||||||
|
"layer1.weight": fp8_weight,
|
||||||
|
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||||
|
"layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
|
||||||
|
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||||
|
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||||
|
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||||
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
|
}
|
||||||
|
|
||||||
|
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# Add a weight function (simulating LoRA)
|
||||||
|
# LoRA delta must match weight shape (20, 10)
|
||||||
|
def apply_lora(weight):
|
||||||
|
# Generate LoRA delta matching weight shape
|
||||||
|
lora_delta = torch.randn_like(weight) * 0.01
|
||||||
|
return weight + lora_delta
|
||||||
|
|
||||||
|
model.layer1.weight_function.append(apply_lora)
|
||||||
|
|
||||||
|
# Forward pass should work with LoRA
|
||||||
|
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||||
|
output = model(input_tensor)
|
||||||
|
|
||||||
|
self.assertEqual(output.shape, (5, 40))
|
||||||
|
|
||||||
|
def test_error_handling_unknown_format(self):
|
||||||
|
"""Test that unknown formats fall back gracefully"""
|
||||||
|
# Configure with unknown format
|
||||||
|
layer_quant_config = {
|
||||||
|
"layer1": {
|
||||||
|
"format": "unknown_format_xyz",
|
||||||
|
"params": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||||
|
|
||||||
|
# Create state dict
|
||||||
|
state_dict = {
|
||||||
|
"layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
|
||||||
|
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||||
|
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||||
|
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||||
|
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||||
|
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load should not crash, just log warning
|
||||||
|
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# Handler should be None (fallback to standard)
|
||||||
|
self.assertIsNone(model.layer1.quant_handler)
|
||||||
|
|
||||||
|
# Forward pass should still work
|
||||||
|
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||||
|
output = model(input_tensor)
|
||||||
|
self.assertEqual(output.shape, (5, 40))
|
||||||
|
|
||||||
|
|
||||||
|
class TestPickOperationsWithMixedPrecision(unittest.TestCase):
|
||||||
|
"""Test pick_operations with mixed precision config"""
|
||||||
|
|
||||||
|
def test_pick_operations_with_layer_quant_config(self):
|
||||||
|
"""Test that pick_operations returns MixedPrecisionOps when config present"""
|
||||||
|
from comfy import supported_models_base
|
||||||
|
|
||||||
|
# Create model config with layer_quant_config
|
||||||
|
model_config = supported_models_base.BASE({})
|
||||||
|
model_config.layer_quant_config = {
|
||||||
|
"layer1": {"format": "fp8_e4m3fn_scaled", "params": {}}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = ops.pick_operations(None, None, model_config=model_config)
|
||||||
|
|
||||||
|
self.assertEqual(result, ops.MixedPrecisionOps)
|
||||||
|
self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config)
|
||||||
|
|
||||||
|
def test_pick_operations_without_layer_quant_config(self):
|
||||||
|
"""Test that pick_operations falls back to standard when no config"""
|
||||||
|
from comfy import supported_models_base
|
||||||
|
|
||||||
|
model_config = supported_models_base.BASE({})
|
||||||
|
model_config.layer_quant_config = None
|
||||||
|
|
||||||
|
result = ops.pick_operations(None, None, model_config=model_config)
|
||||||
|
|
||||||
|
self.assertEqual(result, ops.disable_weight_init)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
||||||
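The test module inserts the repository root into sys.path and calls unittest.main() itself, so it can be run directly from the repository root (running it through pytest should work equally well, assuming pytest is installed):

python tests-unit/comfy_test/test_mixed_precision.py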
262 tests-unit/comfy_test/test_quant_detection.py (new file)

@@ -0,0 +1,262 @@
"""
|
||||||
|
Integration tests for quantization detection.
|
||||||
|
Tests Phase 2: Detection & Integration
|
||||||
|
"""
|
||||||
|
|
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy import model_detection


class TestNormalizeLayerName(unittest.TestCase):
    """Test the normalize_layer_name helper function"""

    def test_strip_prefix_and_suffix(self):
        """Test stripping prefix and suffix"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "model.diffusion_model.layer1.weight",
            known_prefixes
        )
        self.assertEqual(result, "layer1")

    def test_strip_multiple_prefixes(self):
        """Test with multiple known prefixes"""
        known_prefixes = ["model.diffusion_model.", "model.model.", "net."]

        result1 = model_detection.normalize_layer_name(
            "model.diffusion_model.block.attn.weight",
            known_prefixes
        )
        self.assertEqual(result1, "block.attn")

        result2 = model_detection.normalize_layer_name(
            "model.model.encoder.layer.weight",
            known_prefixes
        )
        self.assertEqual(result2, "encoder.layer")

        result3 = model_detection.normalize_layer_name(
            "net.transformer.blocks.0.weight",
            known_prefixes
        )
        self.assertEqual(result3, "transformer.blocks.0")

    def test_strip_scale_weight_suffix(self):
        """Test stripping scale_weight suffix"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "model.diffusion_model.layer1.scale_weight",
            known_prefixes
        )
        self.assertEqual(result, "layer1")

    def test_strip_bias_suffix(self):
        """Test stripping bias suffix"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "model.diffusion_model.layer1.bias",
            known_prefixes
        )
        self.assertEqual(result, "layer1")

    def test_no_prefix_match(self):
        """Test with no prefix match"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "other.model.layer1.weight",
            known_prefixes
        )
        # Should strip suffix but not prefix
        self.assertEqual(result, "other.model.layer1")


class TestDetectLayerQuantization(unittest.TestCase):
    """Test the detect_layer_quantization function"""

    def test_no_quantization(self):
        """Test with no quantization markers"""
        state_dict = {
            "model.diffusion_model.layer1.weight": torch.randn(10, 20),
            "model.diffusion_model.layer2.weight": torch.randn(20, 30),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
        self.assertIsNone(result)

    def test_legacy_scaled_fp8(self):
        """Test that legacy scaled_fp8 marker returns None"""
        # Create FP8 tensor by converting from float32
        fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn),
            "model.diffusion_model.layer1.weight": fp8_weight,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
        # Should return None to trigger legacy path
        self.assertIsNone(result)

    def test_metadata_format(self):
        """Test with new metadata format"""
        metadata = {
            "format_version": "1.0",
            "layers": {
                "layer1": {
                    "format": "fp8_e4m3fn_scaled",
                    "params": {"use_fp8_matmul": True}
                },
                "layer2": {
                    "format": "fp8_e5m2_scaled",
                    "params": {"use_fp8_matmul": True}
                }
            }
        }
        state_dict = {
            "model.diffusion_model._quantization_metadata": metadata,
            "model.diffusion_model.layer1.weight": torch.randn(10, 20),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)
        self.assertIn("layer2", result)
        self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled")
        self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled")
        # Metadata should be popped from state_dict
        self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict)

    def test_mixed_precision_detection(self):
        """Test detection of mixed precision via scale patterns"""
        # Create FP8 tensors by converting from float32
        fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            # Layer 1: FP8 (has scale_weight)
            "model.diffusion_model.layer1.weight": fp8_weight1,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
            # Layer 2: Standard (no scale_weight)
            "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
            # Layer 3: FP8 (has scale_weight)
            "model.diffusion_model.layer3.weight": fp8_weight3,
            "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)
        self.assertIn("layer3", result)
        self.assertNotIn("layer2", result)  # Layer 2 not quantized
        # Scale-based detection reports the plain dtype format name
        self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn")
        self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn")

    def test_all_layers_quantized(self):
        """Test that uniform quantization (all layers) returns None"""
        # Create FP8 tensors by converting from float32
        fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            # All layers have scale_weight
            "model.diffusion_model.layer1.weight": fp8_weight1,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
            "model.diffusion_model.layer2.weight": fp8_weight2,
            "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        # If all layers are quantized, it's not mixed precision
        # Should return None to use legacy scaled_fp8_ops path
        self.assertIsNone(result)

    def test_fp8_e5m2_detection(self):
        """Test detection of FP8 E5M2 format"""
        # Create FP8 E5M2 tensor by converting from float32
        fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2)
        state_dict = {
            "model.diffusion_model.layer1.weight": fp8_weight,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
            "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)
        # Scale-based detection reports the plain dtype format name
        self.assertEqual(result["layer1"]["format"], "fp8_e5m2")

    def test_invalid_metadata(self):
        """Test with invalid metadata format"""
        state_dict = {
            "model.diffusion_model._quantization_metadata": "invalid_string",
            "model.diffusion_model.layer1.weight": torch.randn(10, 20),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
        # Should return None on invalid metadata
        self.assertIsNone(result)

    def test_different_prefix(self):
        """Test with different model prefix (audio model)"""
        # Create FP8 tensor by converting from float32
        fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "model.model.layer1.weight": fp8_weight,
            "model.model.layer1.scale_weight": torch.tensor(1.0),
            "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)


class TestPickOperationsIntegration(unittest.TestCase):
    """Test pick_operations with model_config parameter"""

    def test_backward_compatibility(self):
        """Test that pick_operations works without model_config (legacy)"""
        from comfy import ops

        # Should work without model_config parameter
        result = ops.pick_operations(None, None)
        self.assertIsNotNone(result)
        self.assertEqual(result, ops.disable_weight_init)

    def test_with_model_config_no_quant(self):
        """Test with model_config but no quantization"""
        from comfy import ops, supported_models_base

        model_config = supported_models_base.BASE({})
        model_config.layer_quant_config = None

        result = ops.pick_operations(None, None, model_config=model_config)
        self.assertIsNotNone(result)
        # Should use standard path
        self.assertEqual(result, ops.disable_weight_init)

    def test_legacy_scaled_fp8(self):
        """Test that legacy scaled_fp8 still works"""
        from comfy import ops, supported_models_base

        model_config = supported_models_base.BASE({})
        model_config.scaled_fp8 = torch.float8_e4m3fn

        result = ops.pick_operations(
            None, None,
            scaled_fp8=torch.float8_e4m3fn,
            model_config=model_config
        )
        self.assertIsNotNone(result)
        # Should return scaled_fp8_ops (the returned class is the inner class)
        # Check that it's not the standard disable_weight_init
        self.assertNotEqual(result, ops.disable_weight_init)
        # Verify it has Linear class
        self.assertTrue(hasattr(result, 'Linear'))


if __name__ == "__main__":
    unittest.main()
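Taken together, these tests exercise a three-step flow: detect_layer_quantization inspects the state dict, its result is attached to the model config as layer_quant_config, and pick_operations chooses an ops class from that config. The sketch below wires those calls together in that order; it only uses names that appear in the tests above (detect_layer_quantization, supported_models_base.BASE, pick_operations with a model_config keyword), build_operations is a hypothetical helper name, and this is an illustration of the intended call order rather than the loader's actual code path.

from comfy import model_detection, ops, supported_models_base

def build_operations(state_dict, prefix="model.diffusion_model."):
    # detect_layer_quantization returns a {layer_name: {"format": ..., "params": ...}} dict,
    # or None when the checkpoint is unquantized or carries the uniform legacy scaled-FP8 marker.
    layer_config = model_detection.detect_layer_quantization(state_dict, prefix)
    model_config = supported_models_base.BASE({})
    model_config.layer_quant_config = layer_config
    # With layer_quant_config set to None, pick_operations falls back to the legacy classes,
    # which is exactly what TestPickOperationsIntegration asserts.
    return ops.pick_operations(None, None, model_config=model_config)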
399 tests-unit/comfy_test/test_quant_registry.py (new file)
@ -0,0 +1,399 @@
"""
Unit tests for tensor subclass quantization system.
Tests the new QuantizedTensorFP8 subclass and operation handlers.
"""

import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy import ops
from comfy import quant_ops


class TestQuantizedTensorFP8(unittest.TestCase):
    """Test the QuantizedTensorFP8 tensor subclass"""

    def test_creation(self):
        """Test creating a QuantizedTensorFP8"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)

        qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)

        self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._scale, scale)
        self.assertEqual(qt._orig_dtype, torch.bfloat16)

    def test_dequantize(self):
        """Test explicit dequantization"""
        # Create a simple FP8 tensor
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)

        qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32)
        dequantized = qt.dequantize()

        # Dequantized should be approximately ones * 3.0
        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_repr(self):
        """Test string representation"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.5)

        qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
        repr_str = repr(qt)

        self.assertIn("QuantizedTensorFP8", repr_str)
        self.assertIn("shape", repr_str)
        self.assertIn("scale", repr_str)


class TestOperationRegistry(unittest.TestCase):
    """Test the operation registry system"""

    def test_registry_basics(self):
        """Test that operations are registered"""
        registered_ops = quant_ops.list_registered_ops()

        # Check that key operations are registered
        self.assertIn(torch.ops.aten.linear.default, registered_ops)
        self.assertIn(torch.ops.aten.silu.default, registered_ops)
        self.assertIn(torch.ops.aten.layer_norm.default, registered_ops)
        self.assertIn(torch.ops.aten.add.Tensor, registered_ops)
        self.assertIn(torch.ops.aten.mul.Tensor, registered_ops)

    def test_get_handler(self):
        """Test getting a registered handler"""
        handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default)
        self.assertIsNotNone(handler)
        self.assertTrue(callable(handler))

    def test_custom_registration(self):
        """Test registering a custom operation"""

        # Define a custom handler
        @quant_ops.register_quant_op(torch.ops.aten.relu.default)
        def custom_relu_handler(func, args, kwargs):
            return func(*args, **kwargs)

        # Verify registration
        handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default)
        self.assertIsNotNone(handler)
        self.assertEqual(handler, custom_relu_handler)


class TestLinearHandler(unittest.TestCase):
    """Test the linear operation handler"""

    def test_linear_with_quantized_weight(self):
        """Test F.linear with quantized weight"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized weight
        weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
        scale = torch.tensor(2.0)
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
        weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)

        # Create input
        input_tensor = torch.randn(16, 128, dtype=torch.float32)

        # Call linear (should trigger dispatch)
        output = torch.nn.functional.linear(input_tensor, weight_q, bias=None)

        # Verify output shape
        self.assertEqual(output.shape, (16, 256))

        # Verify it's approximately correct (allowing for FP8 quantization error)
        # Note: FP8 has limited precision, so use very loose tolerance
        expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None)
        # Just check that it's in the right ballpark (within 50% error on average)
        mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large")

    def test_linear_with_bias(self):
        """Test F.linear with quantized weight and bias"""
        weight_fp32 = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
        weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)

        input_tensor = torch.randn(8, 32, dtype=torch.float32)
        bias = torch.randn(64, dtype=torch.float32)

        output = torch.nn.functional.linear(input_tensor, weight_q, bias)

        self.assertEqual(output.shape, (8, 64))


class TestActivationHandlers(unittest.TestCase):
    """Test activation function handlers"""

    def test_silu_with_quantized_input(self):
        """Test SiLU with quantized input"""
        # Create quantized input
        input_fp32 = torch.randn(16, 128, dtype=torch.float32)
        scale = torch.tensor(1.0)
        input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
        input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)

        # Apply SiLU
        output = torch.nn.functional.silu(input_q)

        # Should return a QuantizedTensorFP8
        self.assertIsInstance(output, quant_ops.QuantizedTensorFP8)

        # Verify approximate correctness
        expected = torch.nn.functional.silu(input_fp32)
        output_dq = output.dequantize()
        self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2))

    def test_layernorm_dequantizes(self):
        """Test that LayerNorm dequantizes input"""
        # Create quantized input
        input_fp32 = torch.randn(16, 128, dtype=torch.float32)
        scale = torch.tensor(1.0)
        input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
        input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)

        # Apply LayerNorm
        weight = torch.ones(128)
        bias = torch.zeros(128)
        output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias)

        # Should NOT be quantized (LayerNorm breaks quantization)
        self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8)
        self.assertEqual(output.dtype, torch.float32)


class TestElementwiseHandlers(unittest.TestCase):
    """Test element-wise operation handlers"""

    def test_add_mixed_tensors(self):
        """Test addition with mixed quantized/non-quantized tensors"""
        # Create quantized tensor
        a_fp32 = torch.ones(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
        a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)

        # Non-quantized tensor
        b = torch.ones(10, 20, dtype=torch.float32) * 2.0

        # Add them
        result = a_q + b

        # Should be dequantized
        self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
        self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_mul_quantized_tensors(self):
        """Test multiplication of two quantized tensors"""
        a_fp32 = torch.ones(10, 20) * 2.0
        scale_a = torch.tensor(1.0)
        a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn)
        a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32)

        b_fp32 = torch.ones(10, 20) * 3.0
        scale_b = torch.tensor(1.0)
        b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn)
        b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32)

        result = a_q * b_q

        # Should be dequantized
        self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
        self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
        a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


class TestMixedPrecisionOps(unittest.TestCase):
    """Test MixedPrecisionOps integration"""

    def test_linear_layer_creation(self):
        """Test that MixedPrecisionOps.Linear can be created"""
        layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32)

        self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear)
        self.assertFalse(layer._quantization_initialized)
        self.assertIsNone(layer.quant_format)

    def test_layer_quant_config_detection(self):
        """Test that layer quantization config is detected during load"""
        # Set up layer config
        ops.MixedPrecisionOps._layer_quant_config = {
            "test_layer": {
                "format": "fp8_e4m3fn",
                "params": {}
            }
        }

        # Create a state dict with quantized weight
        weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
        scale = torch.tensor(2.0)
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)

        state_dict = {
            "model.diffusion_model.test_layer.weight": weight_fp8,
            "model.diffusion_model.test_layer.scale_weight": scale,
        }

        # Create layer and load
        layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn)
        layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn))

        # Manually call _load_from_state_dict
        layer._load_from_state_dict(
            state_dict,
            prefix="model.diffusion_model.test_layer.",
            local_metadata={},
            strict=True,
            missing_keys=[],
            unexpected_keys=[],
            error_msgs=[]
        )

        # Verify quantization was initialized
        self.assertTrue(layer._quantization_initialized)
        self.assertEqual(layer.quant_format, "fp8_e4m3fn")
        self.assertIsNotNone(layer.quant_scale)

        # Verify weight is wrapped
        self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8)

        # Clean up
        ops.MixedPrecisionOps._layer_quant_config = {}


class TestBackwardCompatibility(unittest.TestCase):
    """Test backward compatibility with legacy systems"""

    def test_legacy_ops_classes_exist(self):
        """Test that legacy ops classes still exist"""
        self.assertTrue(hasattr(ops, 'disable_weight_init'))
        self.assertTrue(hasattr(ops, 'manual_cast'))
        self.assertTrue(hasattr(ops, 'fp8_ops'))
        self.assertTrue(hasattr(ops, 'scaled_fp8_ops'))

    def test_pick_operations_legacy_path(self):
        """Test pick_operations returns correct class for legacy cases"""
        # Test standard case
        result = ops.pick_operations(torch.float32, torch.float32)
        self.assertEqual(result, ops.disable_weight_init)

        # Test manual cast case
        result = ops.pick_operations(torch.float32, torch.float16)
        self.assertEqual(result, ops.manual_cast)


class TestFP8LinearUnification(unittest.TestCase):
    """Test that fp8_linear now uses the unified tensor subclass infrastructure"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8")
    def test_fp8_linear_uses_tensor_subclass(self):
        """Verify fp8_linear wraps tensors in QuantizedTensorFP8"""
        torch.manual_seed(42)

        # Create a mock Linear layer with FP8 weight
        linear = ops.fp8_ops.Linear(4, 3, bias=True)
        linear.weight = torch.nn.Parameter(
            torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
            requires_grad=False
        )
        linear.bias = torch.nn.Parameter(
            torch.randn(3, dtype=torch.bfloat16),
            requires_grad=False
        )
        linear.scale_weight = torch.tensor(1.0)
        linear.scale_input = None  # No input scaling

        # Create input
        input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)

        # Call fp8_linear - should work without errors
        try:
            result = ops.fp8_linear(linear, input_tensor)
            self.assertIsNotNone(result)
            self.assertEqual(result.shape, (2, 3))
        except Exception:
            # On CPU or unsupported hardware, _scaled_mm might not be available
            # but the function should still complete without syntax errors
            pass

    def test_fp8_linear_maintains_signature(self):
        """Verify fp8_linear maintains its original function signature"""
        import inspect
        sig = inspect.signature(ops.fp8_linear)
        params = list(sig.parameters.keys())

        # Should have 'self' and 'input' parameters
        self.assertIn('self', params)
        self.assertIn('input', params)
        self.assertEqual(len(params), 2)

    def test_fp8_linear_returns_none_for_non_fp8(self):
        """Verify fp8_linear returns None for non-FP8 weights"""
        # Create a Linear layer with BF16 weight (not FP8)
        linear = ops.disable_weight_init.Linear(4, 3, bias=False)
        linear.weight = torch.nn.Parameter(
            torch.randn(3, 4, dtype=torch.bfloat16),
            requires_grad=False
        )

        input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)

        # Should return None for non-FP8 weights
        result = ops.fp8_linear(linear, input_tensor)
        self.assertIsNone(result)

    def test_fp8_ops_linear_uses_fp8_linear(self):
        """Verify fp8_ops.Linear still uses fp8_linear in forward pass"""
        linear = ops.fp8_ops.Linear(4, 3, bias=False)

        # Verify the class has the forward_comfy_cast_weights method
        self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights'))

        # The forward_comfy_cast_weights should attempt to call fp8_linear
        # (we can't easily test this without mocking, but we verify structure)
        import inspect
        source = inspect.getsource(linear.forward_comfy_cast_weights)
        self.assertIn('fp8_linear', source)


if __name__ == "__main__":
    unittest.main()
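These tests treat QuantizedTensorFP8 as a black box, so a compressed picture of the pattern they rely on may help: a wrapper tensor subclass that carries a quantized payload plus a scale and, inside __torch_dispatch__, either routes an op to a registered handler or falls back to dequantizing its arguments. The sketch below shows that general PyTorch mechanism in miniature. It is not comfy/quant_ops.py: _ToyQuantTensor and TOY_HANDLERS are made-up names, the handler signature (func, args, kwargs) simply mirrors test_custom_registration above, and it assumes a PyTorch build that ships the float8 dtypes.

import torch
from torch.utils._pytree import tree_map

TOY_HANDLERS = {}  # aten overload -> handler; stand-in for the registry the tests probe

class _ToyQuantTensor(torch.Tensor):
    """Wrapper subclass: quantized payload plus scale, dequantized lazily."""

    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        # The wrapper only advertises shape/dtype/device; the payload lives in __init__.
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data, scale, orig_dtype):
        self._data, self._scale, self._orig_dtype = data, scale, orig_dtype

    def dequantize(self):
        return self._data.to(self._orig_dtype) * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        handler = TOY_HANDLERS.get(func)
        if handler is not None:
            return handler(func, args, kwargs)
        # Fallback: dequantize every wrapped argument, then run the op normally.
        def unwrap(t):
            return t.dequantize() if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

# torch.abs has no handler here, so it takes the dequantize fallback,
# which is the behavior TestFallbackMechanism checks for the real class.
q = _ToyQuantTensor(torch.randn(4, 4).to(torch.float8_e4m3fn), torch.tensor(1.0), torch.float32)
print(torch.abs(q).dtype)  # torch.float32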