ruff lint

lspindler 2025-10-27 08:52:50 +01:00
parent a7216e18e5
commit 2a8b826442
6 changed files with 117 additions and 118 deletions

View File

@@ -326,7 +326,7 @@ class BaseModel(torch.nn.Module):
        if self.model_config.scaled_fp8 is not None:
            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

        # Save mixed precision metadata
        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
            metadata = {

View File

@@ -16,7 +16,7 @@ def detect_layer_quantization(metadata):
            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
            return quant_metadata["layers"]
        else:
-            raise ValueError(f"Invalid quantization metadata format")
+            raise ValueError("Invalid quantization metadata format")
    return None
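
For orientation, a minimal sketch of the metadata shape this function consumes. Only the "format_version" and "layers" keys and the per-layer "format" field are grounded in the code in this commit; the layer name and the "float8_e4m3fn" format string are illustrative assumptions.

    # Hypothetical metadata example; only "format_version", "layers", and the per-layer
    # "format" key come from this diff, everything else is assumed for illustration.
    quant_metadata = {
        "format_version": "1.0",
        "layers": {
            "layer1": {"format": "float8_e4m3fn"},  # assumed layer name and format identifier
        },
    }
    # detect_layer_quantization() returns quant_metadata["layers"] when the metadata is valid,
    # and None when no quantization metadata is present.
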

View File

@@ -325,7 +325,7 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
    """
    Legacy FP8 linear function for backward compatibility.
    Uses QuantizedTensor subclass for dispatch.
    """
    dtype = self.weight.dtype
@@ -339,7 +339,7 @@ def fp8_linear(self, input):
    input_shape = input.shape
    input_dtype = input.dtype

    if len(input.shape) == 3:
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
@@ -354,14 +354,14 @@ def fp8_linear(self, input):
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_input = scale_input.to(input.device)

        # Wrap weight in QuantizedTensor - this enables unified dispatch
        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

        if tensor_2d:
            return o.reshape(input_shape[0], -1)
        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
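
The two comments above describe the core idea of this refactor: once the weight and the flattened input are wrapped in QuantizedTensor, an ordinary torch.nn.functional.linear call is intercepted by __torch_dispatch__ and routed to the registered FP8 handler. A condensed, self-contained sketch of that pattern follows; shapes and scale values are illustrative, and the FP8 matmul path itself still requires hardware and PyTorch support for torch._scaled_mm.

    # Minimal sketch of the dispatch pattern; values are illustrative, not taken from the diff.
    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

    w_fp8 = torch.randn(20, 10).to(torch.float8_e4m3fn)
    weight_q = QuantizedTensor(w_fp8, TensorCoreFP8Layout,
                               {'scale': torch.tensor(1.0), 'orig_dtype': torch.bfloat16})
    input_q = QuantizedTensor.from_float(torch.randn(5, 10), TensorCoreFP8Layout,
                                         scale=torch.tensor(1.0), fp8_dtype=torch.float8_e4m3fn)
    # A plain F.linear call; __torch_dispatch__ routes it to the fp8_linear handler in quant_ops.py.
    out = torch.nn.functional.linear(input_q, weight_q)
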
@@ -503,8 +503,8 @@ class MixedPrecisionOps(disable_weight_init):
        def reset_parameters(self):
            return None

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys, error_msgs):
            device = self.factory_kwargs["device"]
@@ -520,10 +520,10 @@ class MixedPrecisionOps(disable_weight_init):
            quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
            if quant_format is None:
                raise ValueError(f"Unknown quantization format for layer {layer_name}")

            mixin = QUANT_FORMAT_MIXINS[quant_format]
            self.layout_type = mixin["layout_type"]

            layout_params = {
                'scale': state_dict.pop(f"{prefix}weight_scale", None),
                'orig_dtype': MixedPrecisionOps._compute_dtype
@@ -558,7 +558,7 @@ class MixedPrecisionOps(disable_weight_init):
                    not isinstance(input, QuantizedTensor)):
                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
            return self._forward(input, self.weight, self.bias)


def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
@@ -566,7 +566,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
        MixedPrecisionOps._compute_dtype = compute_dtype
        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
        return MixedPrecisionOps

    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
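
pick_operations now returns MixedPrecisionOps whenever the model config carries a layer_quant_config. A hedged usage sketch: the config contents and the "float8_e4m3fn" format string are assumptions for illustration; only the layer_quant_config attribute, the per-layer "format" key, and the _compute_dtype handling are visible in this diff.

    # Sketch: a per-layer quantization config selecting MixedPrecisionOps (illustrative values).
    class DummyModelConfig:
        layer_quant_config = {
            "layer1": {"format": "float8_e4m3fn"},  # assumed format identifier
        }

    ops_class = pick_operations(torch.bfloat16, torch.bfloat16, model_config=DummyModelConfig())
    # ops_class is MixedPrecisionOps; the config presumably also ends up on the class
    # (MixedPrecisionOps._layer_quant_config) so _load_from_state_dict can look layers up by name.
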

View File

@@ -31,7 +31,7 @@ def register_generic_util(torch_op):
    Decorator to register a generic utility that works for all layouts.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

    Example:
        @register_generic_util(torch.ops.aten.detach.default)
        def generic_detach(func, args, kwargs):
@@ -78,10 +78,10 @@ def _copy_layout_params(params):
class QuantizedLayout:
    """
    Base class for quantization layouts.

    A layout encapsulates the format-specific logic for quantization/dequantization
    and provides a uniform interface for extracting raw tensors needed for computation.

    New quantization formats should subclass this and implement the required methods.
    """
    @classmethod
@@ -90,8 +90,8 @@ class QuantizedLayout:
    @staticmethod
    def dequantize(qdata, **layout_params) -> torch.Tensor:
-        raise NotImplementedError(f"TensorLayout must implement dequantize()")
+        raise NotImplementedError("TensorLayout must implement dequantize()")

    @classmethod
    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
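
The base class above is the extension point for new formats: a subclass supplies quantize(), dequantize(), and get_plain_tensors(), and QuantizedTensor itself stays unchanged. A hypothetical sketch of such a subclass, assuming the (qdata, layout_params) contract shown in from_float() and the (tensor, scale) return used by the FP8 handler; the symmetric INT8 math is purely illustrative.

    # Hypothetical layout; only the method names and return contracts come from this commit.
    class SymmetricInt8Layout(QuantizedLayout):
        @classmethod
        def quantize(cls, tensor, scale=None, **kwargs):
            # Per-tensor symmetric scale if none is provided (assumed convention).
            scale = scale if scale is not None else tensor.abs().amax() / 127.0
            qdata = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
            return qdata, {'scale': scale, 'orig_dtype': tensor.dtype}

        @staticmethod
        def dequantize(qdata, scale=None, orig_dtype=None, **layout_params):
            return (qdata.to(torch.float32) * scale).to(orig_dtype)

        @classmethod
        def get_plain_tensors(cls, qtensor):
            return qtensor._qdata, qtensor._layout_params['scale']
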
@@ -100,45 +100,45 @@
class QuantizedTensor(torch.Tensor):
    """
    Universal quantized tensor that works with any layout.

    This tensor subclass uses a pluggable layout system to support multiple
    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
    The layout_type determines format-specific behavior, while common operations
    (detach, clone, to) are handled generically.

    Attributes:
        _qdata: The quantized tensor data
        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
    """

    @staticmethod
    def __new__(cls, qdata, layout_type, layout_params):
        """
        Create a quantized tensor.

        Args:
            qdata: The quantized data tensor
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata.contiguous()
        self._layout_type = layout_type
        self._layout_params = layout_params

    def __repr__(self):
        layout_name = self._layout_type.__name__
        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

    @property
    def layout_type(self):
        return self._layout_type

    def __tensor_flatten__(self):
        """
        Tensor flattening protocol for proper device movement.
@@ -147,7 +147,7 @@ class QuantizedTensor(torch.Tensor):
        ctx = {
            "layout_type": self._layout_type,
        }

        tensor_params = {}
        non_tensor_params = {}
        for k, v in self._layout_params.items():
@@ -155,17 +155,17 @@ class QuantizedTensor(torch.Tensor):
                tensor_params[k] = v
            else:
                non_tensor_params[k] = v

        ctx["tensor_param_keys"] = list(tensor_params.keys())
        ctx["non_tensor_params"] = non_tensor_params

        for k, v in tensor_params.items():
            attr_name = f"_layout_param_{k}"
            object.__setattr__(self, attr_name, v)
            inner_tensors.append(attr_name)

        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        """
@@ -174,41 +174,41 @@ class QuantizedTensor(torch.Tensor):
        """
        layout_type = ctx["layout_type"]
        layout_params = dict(ctx["non_tensor_params"])

        for key in ctx["tensor_param_keys"]:
            attr_name = f"_layout_param_{key}"
            layout_params[key] = inner_tensors[attr_name]

        return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)

    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
        return self._layout_type.dequantize(self._qdata, **self._layout_params)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Step 1: Check generic utilities first (detach, clone, to, etc.)
        if func in _GENERIC_UTILS:
            return _GENERIC_UTILS[func](func, args, kwargs)

        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
        layout_type = _get_layout_from_args(args)
        if layout_type and func in _LAYOUT_REGISTRY:
            handler = _LAYOUT_REGISTRY[func].get(layout_type)
            if handler:
                return handler(func, args, kwargs)

        # Step 3: Fallback to dequantization
        if isinstance(args[0] if args else None, QuantizedTensor):
            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
        return cls._dequant_and_fallback(func, args, kwargs)

    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        def dequant_arg(arg):
@@ -217,7 +217,7 @@ class QuantizedTensor(torch.Tensor):
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
            return arg

        new_args = dequant_arg(args)
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)
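
Seen from the caller, the three dispatch steps above behave as the new unit tests exercise: generic utilities keep the tensor quantized, registered layout handlers (e.g. linear for TensorCoreFP8Layout) run on the packed data, and anything else falls back to dequantize-and-run. A small CPU-friendly sketch touching steps 1 and 3:

    # Caller-side view of the dispatch order; mirrors the tests added in this commit.
    qt = QuantizedTensor.from_float(
        torch.randn(10, 20, dtype=torch.float32),
        TensorCoreFP8Layout,
        scale=torch.tensor(1.0),
        fp8_dtype=torch.float8_e4m3fn,
    )
    detached = qt.detach()   # step 1: generic utility, result stays a QuantizedTensor
    plain = torch.abs(qt)    # step 3: no handler registered, dequantizes and returns a regular tensor
    assert isinstance(detached, QuantizedTensor)
    assert not isinstance(plain, QuantizedTensor)
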
@@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
            f"but not supported for quantized tensors. Ignoring dtype."
        )

    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:
@@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)

        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)
@@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt
@@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs):
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]

    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor
@@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False
@@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs):
                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
            else:
                return output

        except Exception as e:
            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")

View File

@@ -16,7 +16,7 @@ class SimpleModel(torch.nn.Module):
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
@@ -32,10 +32,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
        """Test that model with no quantization works normally"""
        # Configure no quantization
        ops.MixedPrecisionOps._layer_quant_config = {}

        # Create model
        model = SimpleModel(operations=ops.MixedPrecisionOps)

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
@@ -43,19 +43,19 @@ class TestMixedPrecisionOps(unittest.TestCase):
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Initialize weight_function and bias_function
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
@@ -70,52 +70,52 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)

        state_dict = {
            # Layer 1: FP8 E4M3FN
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),

            # Layer 2: Standard BF16
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),

            # Layer 3: FP8 E4M3FN
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
        }

        # Create model and load state dict (strict=False because custom loading pops keys)
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Verify weights are wrapped in QuantizedTensor
        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)

        # Layer 2 should NOT be quantized
        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

        # Layer 3 should be quantized
        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)

        # Verify scales were loaded
        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))

    def test_state_dict_quantized_preserved(self):
        """Test that quantized weights are preserved in state_dict()"""
        # Configure mixed precision
@@ -126,7 +126,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {
@@ -138,22 +138,22 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict1, strict=False)

        # Save state dict
        state_dict2 = model.state_dict()

        # Verify layer1.weight is a QuantizedTensor with scale preserved
        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)

        # Verify non-quantized layers are standard tensors
        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        # Configure FP8 quantization
@@ -164,7 +164,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
@@ -176,24 +176,24 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Add a weight function (simulating LoRA)
        # This should trigger dequantization during forward pass
        def apply_lora(weight):
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        model.layer1.weight_function.append(apply_lora)

        # Forward pass should work with LoRA (triggers weight_function path)
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats raise error"""
        # Configure with unknown format
@@ -204,7 +204,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
@@ -214,7 +214,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        with self.assertRaises(KeyError):

View File

@@ -11,51 +11,51 @@ from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)

    def test_dequantize(self):
        """Test explicit dequantization"""
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            TensorCoreFP8Layout,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
@@ -64,48 +64,48 @@ class TestQuantizedTensor(unittest.TestCase):
class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()
        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()
        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')
        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
@@ -113,64 +113,63 @@ class TestGenericUtilities(unittest.TestCase):
class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
-        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        a_q = QuantizedTensor.from_float(
            a_fp32,
            TensorCoreFP8Layout,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)