From dd31609c0e85201b8e5f84c23de6681205038f5b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 13 Jan 2026 19:33:05 -0500 Subject: [PATCH] Optimize nvfp4 lora applying. --- comfy/float.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 1a6070bff..8c303bea0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -161,10 +161,7 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): block_size = 16 x = x.reshape(orig_shape[0], -1, block_size) - max_abs = torch.amax(torch.abs(x), dim=-1) - block_scale = max_abs / F4_E2M1_MAX - scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) - scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) generator = torch.Generator(device=x.device) @@ -172,6 +169,5 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): x = x.view(orig_shape).nan_to_num() data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) - blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) return data_lp, blocked_scales