Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 00:12:33 +08:00)
ruff lint
parent a7216e18e5
commit 2a8b826442

@@ -326,7 +326,7 @@ class BaseModel(torch.nn.Module):

        if self.model_config.scaled_fp8 is not None:
            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)

        # Save mixed precision metadata
        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
            metadata = {

@@ -16,7 +16,7 @@ def detect_layer_quantization(metadata):
            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
            return quant_metadata["layers"]
        else:
-            raise ValueError(f"Invalid quantization metadata format")
+            raise ValueError("Invalid quantization metadata format")
    return None

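For reference, a hedged sketch of the metadata shape this function reads. Only the "format_version" and "layers" keys are visible in the hunk above, so the per-layer fields and values below are assumptions rather than the repository's actual schema.

    # Illustrative only; key names inside each layer entry are assumed.
    quant_metadata = {
        "format_version": 1,
        "layers": {
            "diffusion_model.input_blocks.0.0.weight": {"format": "float8_e4m3fn"},
        },
    }
    layer_config = quant_metadata["layers"]  # what detect_layer_quantization returns on success
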
comfy/ops.py (20 lines changed)

@@ -325,7 +325,7 @@ class manual_cast(disable_weight_init):

def fp8_linear(self, input):
    """
    Legacy FP8 linear function for backward compatibility.
    Uses QuantizedTensor subclass for dispatch.
    """
    dtype = self.weight.dtype

@@ -339,7 +339,7 @@ def fp8_linear(self, input):

    input_shape = input.shape
    input_dtype = input.dtype

    if len(input.shape) == 3:
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)

@@ -354,14 +354,14 @@ def fp8_linear(self, input):
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_input = scale_input.to(input.device)

        # Wrap weight in QuantizedTensor - this enables unified dispatch
        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

        if tensor_2d:
            return o.reshape(input_shape[0], -1)
        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

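To make the routing comment above concrete, here is a hedged, self-contained sketch of the same path outside fp8_linear: wrap an FP8 weight and input in QuantizedTensor and call torch.nn.functional.linear, letting __torch_dispatch__ route to the registered handler. Shapes and scale values are illustrative, and exercising the _scaled_mm path assumes hardware and a PyTorch build with FP8 support.

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

    w = torch.randn(40, 30).to(torch.float8_e4m3fn)  # pre-quantized weight (illustrative)
    qw = QuantizedTensor(w, TensorCoreFP8Layout,
                         {'scale': torch.tensor(1.0, dtype=torch.float32), 'orig_dtype': torch.bfloat16})

    x = torch.randn(8, 30, dtype=torch.bfloat16)
    qx = QuantizedTensor.from_float(x, TensorCoreFP8Layout,
                                    scale=torch.ones((), dtype=torch.float32),
                                    fp8_dtype=torch.float8_e4m3fn)

    # __torch_dispatch__ intercepts the linear op and routes it to the fp8_linear
    # handler registered in quant_ops.py (FP8-capable hardware assumed).
    out = torch.nn.functional.linear(qx, qw)
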
@@ -503,8 +503,8 @@ class MixedPrecisionOps(disable_weight_init):

        def reset_parameters(self):
            return None

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys, error_msgs):

            device = self.factory_kwargs["device"]

@@ -520,10 +520,10 @@ class MixedPrecisionOps(disable_weight_init):
            quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
            if quant_format is None:
                raise ValueError(f"Unknown quantization format for layer {layer_name}")

            mixin = QUANT_FORMAT_MIXINS[quant_format]
            self.layout_type = mixin["layout_type"]

            layout_params = {
                'scale': state_dict.pop(f"{prefix}weight_scale", None),
                'orig_dtype': MixedPrecisionOps._compute_dtype

@@ -558,7 +558,7 @@ class MixedPrecisionOps(disable_weight_init):
                    not isinstance(input, QuantizedTensor)):
                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
            return self._forward(input, self.weight, self.bias)


def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:

@@ -566,7 +566,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
        MixedPrecisionOps._compute_dtype = compute_dtype
        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
        return MixedPrecisionOps

    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

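A hedged usage sketch of pick_operations as shown above: a model_config carrying a non-empty layer_quant_config selects MixedPrecisionOps. types.SimpleNamespace stands in for a real model config object, and the per-layer format string is an assumption.

    import types
    import torch
    import comfy.ops as ops

    model_config = types.SimpleNamespace(layer_quant_config={
        "layer1": {"format": "float8_e4m3fn"},  # format value is an assumption
    })
    operations = ops.pick_operations(torch.float8_e4m3fn, torch.bfloat16,
                                     model_config=model_config)
    assert operations is ops.MixedPrecisionOps  # per the branch shown above
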
@@ -31,7 +31,7 @@ def register_generic_util(torch_op):
    Decorator to register a generic utility that works for all layouts.
    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

    Example:
        @register_generic_util(torch.ops.aten.detach.default)
        def generic_detach(func, args, kwargs):

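The docstring example above is cut off by the hunk boundary; a hedged completion could look like the following. The real generic_detach in quant_ops.py may differ in detail, and re-registering the op would overwrite the existing handler, so treat this purely as an illustration of the decorator's contract.

    import torch
    from comfy.quant_ops import register_generic_util, QuantizedTensor

    @register_generic_util(torch.ops.aten.detach.default)
    def generic_detach(func, args, kwargs):
        qt = args[0]
        # Re-wrap the detached payload so the result stays a QuantizedTensor.
        return QuantizedTensor(qt._qdata.detach(), qt._layout_type, dict(qt._layout_params))
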
@@ -78,10 +78,10 @@ def _copy_layout_params(params):
class QuantizedLayout:
    """
    Base class for quantization layouts.

    A layout encapsulates the format-specific logic for quantization/dequantization
    and provides a uniform interface for extracting raw tensors needed for computation.

    New quantization formats should subclass this and implement the required methods.
    """
    @classmethod

@@ -90,8 +90,8 @@ class QuantizedLayout:

    @staticmethod
    def dequantize(qdata, **layout_params) -> torch.Tensor:
-        raise NotImplementedError(f"TensorLayout must implement dequantize()")
+        raise NotImplementedError("TensorLayout must implement dequantize()")

    @classmethod
    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")

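As a hedged illustration of the subclassing contract described above, a toy per-tensor INT8 layout might look like this. It is not part of the diff; the method set mirrors only what is visible in these hunks (quantize, dequantize, get_plain_tensors).

    import torch
    from comfy.quant_ops import QuantizedLayout

    class ToyInt8Layout(QuantizedLayout):
        """Illustrative symmetric INT8 layout (not in the diff)."""

        @classmethod
        def quantize(cls, tensor, scale=None, **kwargs):
            if scale is None:
                scale = tensor.abs().max().clamp(min=1e-8) / 127.0
            qdata = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
            return qdata, {'scale': scale, 'orig_dtype': tensor.dtype}

        @staticmethod
        def dequantize(qdata, scale, orig_dtype, **layout_params):
            return qdata.to(orig_dtype) * scale

        @classmethod
        def get_plain_tensors(cls, qtensor):
            return qtensor._qdata, qtensor._layout_params['scale']
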
@@ -100,45 +100,45 @@ class QuantizedLayout:
class QuantizedTensor(torch.Tensor):
    """
    Universal quantized tensor that works with any layout.

    This tensor subclass uses a pluggable layout system to support multiple
    quantization formats (FP8, INT4, INT8, etc.) without code duplication.

    The layout_type determines format-specific behavior, while common operations
    (detach, clone, to) are handled generically.

    Attributes:
        _qdata: The quantized tensor data
        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
    """

    @staticmethod
    def __new__(cls, qdata, layout_type, layout_params):
        """
        Create a quantized tensor.

        Args:
            qdata: The quantized data tensor
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata.contiguous()
        self._layout_type = layout_type
        self._layout_params = layout_params

    def __repr__(self):
        layout_name = self._layout_type.__name__
        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

    @property
    def layout_type(self):
        return self._layout_type

    def __tensor_flatten__(self):
        """
        Tensor flattening protocol for proper device movement.

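A short hedged example of the attributes and repr documented above; values are illustrative.

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

    qdata = torch.zeros(2, 3).to(torch.float8_e4m3fn)
    qt = QuantizedTensor(qdata, TensorCoreFP8Layout,
                         {'scale': torch.tensor(1.0), 'orig_dtype': torch.float32})
    print(qt)              # repr shows shape, layout name and the first layout params
    print(qt.layout_type)  # exposed through the layout_type property
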
@@ -147,7 +147,7 @@ class QuantizedTensor(torch.Tensor):
        ctx = {
            "layout_type": self._layout_type,
        }

        tensor_params = {}
        non_tensor_params = {}
        for k, v in self._layout_params.items():

@@ -155,17 +155,17 @@ class QuantizedTensor(torch.Tensor):
                tensor_params[k] = v
            else:
                non_tensor_params[k] = v

        ctx["tensor_param_keys"] = list(tensor_params.keys())
        ctx["non_tensor_params"] = non_tensor_params

        for k, v in tensor_params.items():
            attr_name = f"_layout_param_{k}"
            object.__setattr__(self, attr_name, v)
            inner_tensors.append(attr_name)

        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        """

@@ -174,41 +174,41 @@ class QuantizedTensor(torch.Tensor):
        """
        layout_type = ctx["layout_type"]
        layout_params = dict(ctx["non_tensor_params"])

        for key in ctx["tensor_param_keys"]:
            attr_name = f"_layout_param_{key}"
            layout_params[key] = inner_tensors[attr_name]

        return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)

    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
        return self._layout_type.dequantize(self._qdata, **self._layout_params)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Step 1: Check generic utilities first (detach, clone, to, etc.)
        if func in _GENERIC_UTILS:
            return _GENERIC_UTILS[func](func, args, kwargs)

        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
        layout_type = _get_layout_from_args(args)
        if layout_type and func in _LAYOUT_REGISTRY:
            handler = _LAYOUT_REGISTRY[func].get(layout_type)
            if handler:
                return handler(func, args, kwargs)

        # Step 3: Fallback to dequantization
        if isinstance(args[0] if args else None, QuantizedTensor):
            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
        return cls._dequant_and_fallback(func, args, kwargs)

    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        def dequant_arg(arg):

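A hedged sketch of the three-step dispatch order above, using ops that appear elsewhere in this diff: detach goes through a generic utility and stays quantized, while an op with no registered handler (torch.abs) takes the dequantize-and-fallback path and returns a plain tensor.

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

    qt = QuantizedTensor.from_float(torch.randn(4, 8), TensorCoreFP8Layout,
                                    scale=torch.tensor(1.0),
                                    fp8_dtype=torch.float8_e4m3fn)

    d = qt.detach()    # step 1: generic utility, result is still a QuantizedTensor
    a = torch.abs(qt)  # step 3: no handler registered, falls back to dequantization
    print(type(d).__name__, type(a).__name__)
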
@@ -217,7 +217,7 @@ class QuantizedTensor(torch.Tensor):
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
            return arg

        new_args = dequant_arg(args)
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)

@@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
            f"but not supported for quantized tensors. Ignoring dtype."
        )

    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:

@@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)

        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)

@@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt


@@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs):
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]

    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor

@@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False

@@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs):
                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
            else:
                return output

        except Exception as e:
            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")


@@ -16,7 +16,7 @@ class SimpleModel(torch.nn.Module):
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)

@@ -32,10 +32,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
        """Test that model with no quantization works normally"""
        # Configure no quantization
        ops.MixedPrecisionOps._layer_quant_config = {}

        # Create model
        model = SimpleModel(operations=ops.MixedPrecisionOps)

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))

@@ -43,19 +43,19 @@ class TestMixedPrecisionOps(unittest.TestCase):
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Initialize weight_function and bias_function
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard

@@ -70,52 +70,52 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)

        state_dict = {
            # Layer 1: FP8 E4M3FN
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),

            # Layer 2: Standard BF16
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),

            # Layer 3: FP8 E4M3FN
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
        }

        # Create model and load state dict (strict=False because custom loading pops keys)
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Verify weights are wrapped in QuantizedTensor
        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)

        # Layer 2 should NOT be quantized
        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

        # Layer 3 should be quantized
        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)

        # Verify scales were loaded
        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_state_dict_quantized_preserved(self):
        """Test that quantized weights are preserved in state_dict()"""
        # Configure mixed precision

@@ -126,7 +126,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {

@@ -138,22 +138,22 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict1, strict=False)

        # Save state dict
        state_dict2 = model.state_dict()

        # Verify layer1.weight is a QuantizedTensor with scale preserved
        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)

        # Verify non-quantized layers are standard tensors
        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        # Configure FP8 quantization

@@ -164,7 +164,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {

@@ -176,24 +176,24 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict, strict=False)

        # Add a weight function (simulating LoRA)
        # This should trigger dequantization during forward pass
        def apply_lora(weight):
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        model.layer1.weight_function.append(apply_lora)

        # Forward pass should work with LoRA (triggers weight_function path)
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats raise error"""
        # Configure with unknown format

@@ -204,7 +204,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),

@@ -214,7 +214,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        with self.assertRaises(KeyError):

@@ -11,51 +11,51 @@ from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)

    def test_dequantize(self):
        """Test explicit dequantization"""

        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            TensorCoreFP8Layout,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()

@@ -64,48 +64,48 @@ class TestQuantizedTensor(unittest.TestCase):

class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()

        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()

        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')

        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')

@@ -113,64 +113,63 @@ class TestGenericUtilities(unittest.TestCase):

class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
-        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        a_q = QuantizedTensor.from_float(
            a_fp32,
            TensorCoreFP8Layout,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)