From 59a2e8c74e1df09c5590abd95e4908c9704a687c Mon Sep 17 00:00:00 2001
From: lspindler
Date: Tue, 28 Oct 2025 07:33:19 +0100
Subject: [PATCH] Rename quant dtype parameter

---
 comfy/ops.py       | 2 +-
 comfy/quant_ops.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index d7a8873e2..93731eedf 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -379,7 +379,7 @@ def fp8_linear(self, input):
     # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
     layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
     quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
-    quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
+    quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
 
     o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
     if tensor_2d:
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index aa1a231bd..b14e03084 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -347,20 +347,20 @@ class TensorCoreFP8Layout(QuantizedLayout):
     - orig_dtype: Original dtype before quantization (for casting back)
     """
     @classmethod
-    def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn):
+    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
         orig_dtype = tensor.dtype
 
         if scale is None:
-            scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
+            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
 
         if not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        lp_amax = torch.finfo(fp8_dtype).max
+        lp_amax = torch.finfo(dtype).max
         tensor_scaled = tensor.float() / scale
         torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-        qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
             'scale': scale,