diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 571d3f760..cd96541d7 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -399,7 +399,10 @@ class TensorCoreFP8Layout(QuantizedLayout): orig_dtype = tensor.dtype if isinstance(scale, str) and scale == "recalculate": - scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small + tensor_info = torch.finfo(tensor.dtype) + scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) if scale is not None: if not isinstance(scale, torch.Tensor):