mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-14 16:20:50 +08:00
Optimize nvfp4 lora applying. (#11854)
This commit is contained in:
parent
1419047fdb
commit
15b312de7a
@ -165,20 +165,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
|
||||
block_scale = max_abs / F4_E2M1_MAX
|
||||
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
|
||||
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
|
||||
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
|
||||
|
||||
# Handle zero blocks (from padding): avoid 0/0 NaN
|
||||
zero_scale_mask = (total_scale == 0)
|
||||
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
|
||||
|
||||
x = x / total_scale_safe.unsqueeze(-1)
|
||||
x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
|
||||
|
||||
generator = torch.Generator(device=x.device)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
|
||||
|
||||
x = x.view(orig_shape)
|
||||
x = x.view(orig_shape).nan_to_num()
|
||||
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
|
||||
|
||||
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user