From 15b312de7a74a836fa45b989a7697895b01e0cbf Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 13 Jan 2026 16:23:58 -0800
Subject: [PATCH] Optimize nvfp4 lora applying. (#11854)

---
 comfy/float.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/comfy/float.py b/comfy/float.py
index c806af76b..1a6070bff 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -165,20 +165,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
     block_scale = max_abs / F4_E2M1_MAX
     scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
     scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
-    total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
-
-    # Handle zero blocks (from padding): avoid 0/0 NaN
-    zero_scale_mask = (total_scale == 0)
-    total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
-
-    x = x / total_scale_safe.unsqueeze(-1)
+    x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)

     generator = torch.Generator(device=x.device)
     generator.manual_seed(seed)

-    x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
-
-    x = x.view(orig_shape)
+    x = x.view(orig_shape).nan_to_num()

     data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
     blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
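
The patch drops the explicit zero-scale masking and instead lets the all-zero padding blocks divide to 0/0 = NaN, cleaning them up afterwards with a single nan_to_num() call. Below is a minimal standalone sketch of the two approaches on toy data (tensor names blocks and scales are hypothetical and not from the patch); it only illustrates why the NaN-based path produces the same result for zero-padded blocks with fewer intermediate tensors:

import torch

# Hypothetical shapes: 2 blocks of 16 values; the second block is all zeros,
# as happens when a tensor is padded up to a multiple of 16.
blocks = torch.zeros(2, 16)
blocks[0] = torch.linspace(-1.0, 1.0, 16)

# Max-abs per-block scale: the all-zero padding block gets scale 0.
scales = blocks.abs().amax(dim=-1)          # tensor([1., 0.])

# Old approach: replace zero scales with 1 before dividing, then zero the
# affected rows afterwards (extra intermediates and kernel launches).
safe = torch.where(scales == 0, torch.ones_like(scales), scales)
old = blocks / safe.unsqueeze(-1)
old = torch.where((scales == 0).unsqueeze(-1), torch.zeros_like(old), old)

# New approach: divide directly; 0/0 in the padding block yields NaN,
# which one nan_to_num() call turns back into 0.
new = (blocks / scales.unsqueeze(-1)).nan_to_num()

assert torch.equal(old, new)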