mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
ruff lint
This commit is contained in:
parent
a7216e18e5
commit
2a8b826442
@ -326,7 +326,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
|
||||
@ -16,7 +16,7 @@ def detect_layer_quantization(metadata):
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError(f"Invalid quantization metadata format")
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
20
comfy/ops.py
20
comfy/ops.py
@ -325,7 +325,7 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Uses QuantizedTensor subclass for dispatch.
|
||||
"""
|
||||
dtype = self.weight.dtype
|
||||
@ -339,7 +339,7 @@ def fp8_linear(self, input):
|
||||
|
||||
input_shape = input.shape
|
||||
input_dtype = input.dtype
|
||||
|
||||
|
||||
if len(input.shape) == 3:
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||
|
||||
@ -354,14 +354,14 @@ def fp8_linear(self, input):
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
|
||||
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
@ -503,8 +503,8 @@ class MixedPrecisionOps(disable_weight_init):
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
@ -520,10 +520,10 @@ class MixedPrecisionOps(disable_weight_init):
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
|
||||
mixin = QUANT_FORMAT_MIXINS[quant_format]
|
||||
self.layout_type = mixin["layout_type"]
|
||||
|
||||
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(f"{prefix}weight_scale", None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype
|
||||
@ -558,7 +558,7 @@ class MixedPrecisionOps(disable_weight_init):
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
@ -566,7 +566,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
|
||||
MixedPrecisionOps._compute_dtype = compute_dtype
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return MixedPrecisionOps
|
||||
|
||||
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
@ -31,7 +31,7 @@ def register_generic_util(torch_op):
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
@ -78,10 +78,10 @@ def _copy_layout_params(params):
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Base class for quantization layouts.
|
||||
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
@ -90,8 +90,8 @@ class QuantizedLayout:
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError(f"TensorLayout must implement dequantize()")
|
||||
|
||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
@ -100,45 +100,45 @@ class QuantizedLayout:
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
|
||||
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata.contiguous()
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type.__name__
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
@ -147,7 +147,7 @@ class QuantizedTensor(torch.Tensor):
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
@ -155,17 +155,17 @@ class QuantizedTensor(torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
@ -174,41 +174,41 @@ class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
|
||||
return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return self._layout_type.dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
def dequant_arg(arg):
|
||||
@ -217,7 +217,7 @@ class QuantizedTensor(torch.Tensor):
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
return arg
|
||||
|
||||
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs):
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
|
||||
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs):
|
||||
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ class SimpleModel(torch.nn.Module):
|
||||
self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
|
||||
self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
|
||||
self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer1(x)
|
||||
x = torch.nn.functional.relu(x)
|
||||
@ -32,10 +32,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"""Test that model with no quantization works normally"""
|
||||
# Configure no quantization
|
||||
ops.MixedPrecisionOps._layer_quant_config = {}
|
||||
|
||||
|
||||
# Create model
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
|
||||
|
||||
# Initialize weights manually
|
||||
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
|
||||
model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
|
||||
@ -43,19 +43,19 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
|
||||
model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
|
||||
model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))
|
||||
|
||||
|
||||
# Initialize weight_function and bias_function
|
||||
for layer in [model.layer1, model.layer2, model.layer3]:
|
||||
layer.weight_function = []
|
||||
layer.bias_function = []
|
||||
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
self.assertEqual(output.dtype, torch.bfloat16)
|
||||
|
||||
|
||||
def test_mixed_precision_load(self):
|
||||
"""Test loading a mixed precision model from state dict"""
|
||||
# Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
|
||||
@ -70,52 +70,52 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
|
||||
# Create state dict with mixed precision
|
||||
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
state_dict = {
|
||||
# Layer 1: FP8 E4M3FN
|
||||
"layer1.weight": fp8_weight1,
|
||||
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
|
||||
|
||||
|
||||
# Layer 2: Standard BF16
|
||||
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||
|
||||
|
||||
# Layer 3: FP8 E4M3FN
|
||||
"layer3.weight": fp8_weight3,
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
|
||||
# Create model and load state dict (strict=False because custom loading pops keys)
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
|
||||
|
||||
|
||||
# Layer 2 should NOT be quantized
|
||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||
|
||||
|
||||
# Layer 3 should be quantized
|
||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
|
||||
|
||||
|
||||
# Verify scales were loaded
|
||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
||||
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
|
||||
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
|
||||
def test_state_dict_quantized_preserved(self):
|
||||
"""Test that quantized weights are preserved in state_dict()"""
|
||||
# Configure mixed precision
|
||||
@ -126,7 +126,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
|
||||
# Create and load model
|
||||
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
state_dict1 = {
|
||||
@ -138,22 +138,22 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model.load_state_dict(state_dict1, strict=False)
|
||||
|
||||
|
||||
# Save state dict
|
||||
state_dict2 = model.state_dict()
|
||||
|
||||
|
||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
|
||||
|
||||
|
||||
# Verify non-quantized layers are standard tensors
|
||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||
self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)
|
||||
|
||||
|
||||
def test_weight_function_compatibility(self):
|
||||
"""Test that weight_function (LoRA) works with quantized layers"""
|
||||
# Configure FP8 quantization
|
||||
@ -164,7 +164,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
|
||||
# Create and load model
|
||||
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
state_dict = {
|
||||
@ -176,24 +176,24 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
# Add a weight function (simulating LoRA)
|
||||
# This should trigger dequantization during forward pass
|
||||
def apply_lora(weight):
|
||||
lora_delta = torch.randn_like(weight) * 0.01
|
||||
return weight + lora_delta
|
||||
|
||||
|
||||
model.layer1.weight_function.append(apply_lora)
|
||||
|
||||
|
||||
# Forward pass should work with LoRA (triggers weight_function path)
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
|
||||
def test_error_handling_unknown_format(self):
|
||||
"""Test that unknown formats raise error"""
|
||||
# Configure with unknown format
|
||||
@ -204,7 +204,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
|
||||
# Create state dict
|
||||
state_dict = {
|
||||
"layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
|
||||
@ -214,7 +214,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
|
||||
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
with self.assertRaises(KeyError):
|
||||
|
||||
@ -11,51 +11,51 @@ from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||
|
||||
class TestQuantizedTensor(unittest.TestCase):
|
||||
"""Test the QuantizedTensor subclass with FP8 layout"""
|
||||
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
|
||||
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(2.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||
|
||||
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.shape, (256, 128))
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt._layout_params['scale'], scale)
|
||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
|
||||
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test explicit dequantization"""
|
||||
|
||||
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(3.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
|
||||
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
dequantized = qt.dequantize()
|
||||
|
||||
|
||||
self.assertEqual(dequantized.dtype, torch.float32)
|
||||
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
||||
|
||||
|
||||
def test_from_float(self):
|
||||
"""Test creating QuantizedTensor from float tensor"""
|
||||
float_tensor = torch.randn(64, 32, dtype=torch.float32)
|
||||
scale = torch.tensor(1.5)
|
||||
|
||||
|
||||
qt = QuantizedTensor.from_float(
|
||||
float_tensor,
|
||||
TensorCoreFP8Layout,
|
||||
float_tensor,
|
||||
TensorCoreFP8Layout,
|
||||
scale=scale,
|
||||
fp8_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt.shape, (64, 32))
|
||||
|
||||
|
||||
# Verify dequantization gives approximately original values
|
||||
dequantized = qt.dequantize()
|
||||
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
|
||||
@ -64,48 +64,48 @@ class TestQuantizedTensor(unittest.TestCase):
|
||||
|
||||
class TestGenericUtilities(unittest.TestCase):
|
||||
"""Test generic utility operations"""
|
||||
|
||||
|
||||
def test_detach(self):
|
||||
"""Test detach operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
|
||||
|
||||
# Detach should return a new QuantizedTensor
|
||||
qt_detached = qt.detach()
|
||||
|
||||
|
||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
||||
self.assertEqual(qt_detached.shape, qt.shape)
|
||||
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
|
||||
|
||||
|
||||
def test_clone(self):
|
||||
"""Test clone operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
|
||||
|
||||
# Clone should return a new QuantizedTensor
|
||||
qt_cloned = qt.clone()
|
||||
|
||||
|
||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
||||
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
|
||||
|
||||
|
||||
# Verify it's a deep copy
|
||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||
|
||||
|
||||
def test_to_device(self):
|
||||
"""Test device transfer"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
||||
|
||||
|
||||
# Moving to same device should work (CPU to CPU)
|
||||
qt_cpu = qt.to('cpu')
|
||||
|
||||
|
||||
self.assertIsInstance(qt_cpu, QuantizedTensor)
|
||||
self.assertEqual(qt_cpu.device.type, 'cpu')
|
||||
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
|
||||
@ -113,64 +113,63 @@ class TestGenericUtilities(unittest.TestCase):
|
||||
|
||||
class TestTensorCoreFP8Layout(unittest.TestCase):
|
||||
"""Test the TensorCoreFP8Layout implementation"""
|
||||
|
||||
|
||||
def test_quantize(self):
|
||||
"""Test quantization method"""
|
||||
float_tensor = torch.randn(32, 64, dtype=torch.float32)
|
||||
scale = torch.tensor(1.5)
|
||||
|
||||
|
||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
||||
float_tensor,
|
||||
scale=scale,
|
||||
fp8_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qdata.shape, float_tensor.shape)
|
||||
self.assertIn('scale', layout_params)
|
||||
self.assertIn('orig_dtype', layout_params)
|
||||
self.assertEqual(layout_params['orig_dtype'], torch.float32)
|
||||
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test dequantization method"""
|
||||
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
|
||||
scale = torch.tensor(1.0)
|
||||
|
||||
|
||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
||||
float_tensor,
|
||||
scale=scale,
|
||||
fp8_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
|
||||
|
||||
|
||||
# Should approximately match original
|
||||
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
|
||||
|
||||
|
||||
class TestFallbackMechanism(unittest.TestCase):
|
||||
"""Test fallback for unsupported operations"""
|
||||
|
||||
|
||||
def test_unsupported_op_dequantizes(self):
|
||||
"""Test that unsupported operations fall back to dequantization"""
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
# Create quantized tensor
|
||||
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
||||
scale = torch.tensor(1.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
a_q = QuantizedTensor.from_float(
|
||||
a_fp32,
|
||||
TensorCoreFP8Layout,
|
||||
scale=scale,
|
||||
fp8_dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
# Call an operation that doesn't have a registered handler
|
||||
# For example, torch.abs
|
||||
result = torch.abs(a_q)
|
||||
|
||||
|
||||
# Should work via fallback (dequantize → abs → return)
|
||||
self.assertNotIsInstance(result, QuantizedTensor)
|
||||
expected = torch.abs(a_fp32)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user