Reduce memory usage for fp8 scaled op. (#10531)

This commit is contained in:
comfyanonymous 2025-10-29 12:43:51 -07:00 committed by GitHub
parent 6c14f3afac
commit 1a58087ac2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -358,7 +358,7 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = scale.to(device=tensor.device, dtype=torch.float32)
lp_amax = torch.finfo(dtype).max
tensor_scaled = tensor.float() / scale
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)