Optimize nvfp4 lora applying.

This commit is contained in:
parent 15b312de7a
commit dd31609c0e
@@ -161,10 +161,7 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
     block_size = 16
 
     x = x.reshape(orig_shape[0], -1, block_size)
-    max_abs = torch.amax(torch.abs(x), dim=-1)
-    block_scale = max_abs / F4_E2M1_MAX
-    scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
-    scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
+    scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
     x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
 
     generator = torch.Generator(device=x.device)
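The change above folds the per-block scale computation into a single expression, avoiding the intermediate max_abs, block_scale, and scaled_block_scales tensors. Below is a minimal standalone sketch (editorial, not the repository's code) of what that fused expression computes, assuming F4_E2M1_MAX = 6.0 and F8_E4M3_MAX = 448.0, the maximum magnitudes representable in FP4 E2M1 and torch.float8_e4m3fn; the helper function names are illustrative only.

    # Editorial sketch, not the repository's code: the fused block-scale expression
    # versus the removed step-by-step version, under assumed constants.
    import torch

    F4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
    F8_E4M3_MAX = 448.0  # largest magnitude representable in torch.float8_e4m3fn

    def fused_block_scales(x, per_tensor_scale):
        # x has shape (rows, n_blocks, 16), as produced by the reshape above
        return torch.clamp(
            (torch.amax(torch.abs(x), dim=-1) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype),
            max=F8_E4M3_MAX,
        ).to(torch.float8_e4m3fn)

    def stepwise_block_scales(x, per_tensor_scale):
        # the pre-commit formulation, kept here only for comparison
        max_abs = torch.amax(torch.abs(x), dim=-1)
        block_scale = max_abs / F4_E2M1_MAX
        scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
        return torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)

    x = torch.randn(4, 8, 16)
    per_tensor_scale = torch.tensor(0.5)
    a = fused_block_scales(x, per_tensor_scale).to(torch.float32)
    b = stepwise_block_scales(x, per_tensor_scale).to(torch.float32)
    assert torch.equal(a, b)

On float32 inputs the two formulations perform the same operations in the same dtype, so the fused version changes only allocation behaviour, not the resulting FP8 scales.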
@@ -172,6 +169,5 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
 
     x = x.view(orig_shape).nan_to_num()
     data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
-
     blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
     return data_lp, blocked_scales
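For readers unfamiliar with the surrounding function: after the per-block scaling above, the data is stochastically rounded onto the FP4 E2M1 grid and returned alongside the blocked FP8 scales. The sketch below is a generic illustration of stochastic rounding onto that grid; it is not the repository's stochastic_float_to_fp4_e2m1 and does not pack the result into a 4-bit layout. The grid values are the non-negative numbers representable in FP4 E2M1.

    # Editorial illustration only: unbiased stochastic rounding of pre-scaled
    # values onto the FP4 E2M1 value grid (sign handled separately, no bit packing).
    import torch

    # Non-negative values representable in FP4 E2M1.
    E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

    def stochastic_round_to_e2m1_values(x, generator=None):
        sign = torch.sign(x)
        mag = torch.clamp(torch.abs(x), max=E2M1_GRID[-1])
        # upper neighbour: first grid point >= mag; lower neighbour sits just below it
        hi = torch.clamp(torch.searchsorted(E2M1_GRID, mag), max=E2M1_GRID.numel() - 1)
        lo = torch.clamp(hi - 1, min=0)
        lo_v, hi_v = E2M1_GRID[lo], E2M1_GRID[hi]
        # probability of rounding up grows linearly with the distance from the lower
        # neighbour, so the rounded value is unbiased in expectation
        p_up = (mag - lo_v) / (hi_v - lo_v).clamp(min=1e-12)
        rnd = torch.rand(mag.shape, generator=generator)
        return sign * torch.where(rnd < p_up, hi_v, lo_v)

    g = torch.Generator().manual_seed(0)
    print(stochastic_round_to_e2m1_values(torch.tensor([0.7, -2.4, 5.0]), generator=g))

Seeding the generator, as the function signature's seed argument suggests, makes the rounding reproducible across runs while still averaging out quantization bias over many elements.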