From 6165c38cb58c40b15ade879b80051b6c9148587f Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 13 Jan 2026 21:49:38 -0800
Subject: [PATCH] Optimize nvfp4 lora applying. (#11866)

This changes the results a bit, but it also speeds things up a lot.
---
 comfy/float.py            | 56 ++++++++++++++++++++++++++++++++-------
 comfy/quant_ops.py        |  2 +-
 comfy/supported_models.py |  2 +-
 3 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/comfy/float.py b/comfy/float.py
index 8c303bea0..88c47cd80 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -137,10 +137,44 @@ def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
     return rearranged.reshape(padded_rows, padded_cols)
 
 
-def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
+def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator):
     F4_E2M1_MAX = 6.0
     F8_E4M3_MAX = 448.0
 
+    orig_shape = x.shape
+
+    block_size = 16
+
+    x = x.reshape(orig_shape[0], -1, block_size)
+    scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
+    x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
+
+    x = x.view(orig_shape).nan_to_num()
+    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
+    return data_lp, scaled_block_scales_fp8
+
+
+def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
+    def roundup(x: int, multiple: int) -> int:
+        """Round up x to the nearest multiple."""
+        return ((x + multiple - 1) // multiple) * multiple
+
+    generator = torch.Generator(device=x.device)
+    generator.manual_seed(seed)
+
+    # Handle padding
+    if pad_16x:
+        rows, cols = x.shape
+        padded_rows = roundup(rows, 16)
+        padded_cols = roundup(cols, 16)
+        if padded_rows != rows or padded_cols != cols:
+            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
+
+    x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator)
+    return x, to_blocked(blocked_scaled, flatten=False)
+
+
+def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096):
     def roundup(x: int, multiple: int) -> int:
         """Round up x to the nearest multiple."""
         return ((x + multiple - 1) // multiple) * multiple
@@ -158,16 +192,20 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
     # what we want to produce. If we pad here, we want the padded output.
     orig_shape = x.shape
 
-    block_size = 16
+    orig_shape = list(orig_shape)
 
-    x = x.reshape(orig_shape[0], -1, block_size)
-    scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
-    x /= (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1)
+    output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device)
+    output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device)
 
     generator = torch.Generator(device=x.device)
     generator.manual_seed(seed)
 
-    x = x.view(orig_shape).nan_to_num()
-    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
-    blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
-    return data_lp, blocked_scales
+    num_slices = max(1, (x.numel() / block_size))
+    slice_size = max(1, (round(x.shape[0] / num_slices)))
+
+    for i in range(0, x.shape[0], slice_size):
+        fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator)
+        output_fp4[i:i + slice_size].copy_(fp4)
+        output_block[i:i + slice_size].copy_(block)
+
+    return output_fp4, to_blocked(output_block, flatten=False)
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 7a61203c3..15a4f457b 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -104,7 +104,7 @@ class TensorCoreNVFP4Layout(_CKNvfp4Layout):
         needs_padding = padded_shape != orig_shape
 
         if stochastic_rounding > 0:
-            qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
+            qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
         else:
             qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
 
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1bf54f13f..2c4c6b8fc 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1042,7 +1042,7 @@ class ZImage(Lumina2):
         "shift": 3.0,
     }
 
-    memory_usage_factor = 2.0
+    memory_usage_factor = 2.8
 
     supported_inference_dtypes = [torch.bfloat16, torch.float32]
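
For context on what the new by-block path is doing, below is a minimal, self-contained sketch of
the same row-chunking pattern. quantize_rows is a hypothetical stand-in for
stochastic_round_quantize_nvfp4_block that only reproduces the output shapes (a packed fp4
payload plus one fp8 scale per 16-element block); quantize_by_row_chunks mirrors the patch's
slicing arithmetic. All names outside the patch (quantize_rows, quantize_by_row_chunks,
chunk_elems) are illustrative, not part of the ComfyUI API.

    import torch

    def quantize_rows(x, generator):
        # Hypothetical stand-in for comfy.float.stochastic_round_quantize_nvfp4_block.
        # The real code packs two fp4 (E2M1) values per uint8 byte and emits one
        # fp8 (E4M3) scale per 16-element block; here we only mimic the shapes.
        rows, cols = x.shape
        blocks = x.reshape(rows, -1, 16)
        scales = (blocks.abs().amax(dim=-1) / 6.0).to(torch.float8_e4m3fn)  # 6.0 = fp4 E2M1 max
        noise = torch.rand((rows, cols // 2), generator=generator, device=x.device)
        packed = (noise * 15).to(torch.uint8)  # placeholder payload, not real fp4 bits
        return packed, scales

    def quantize_by_row_chunks(x, seed=0, chunk_elems=4096 * 4096):
        # Mirrors stochastic_round_quantize_nvfp4_by_block: process roughly
        # chunk_elems elements at a time, as whole rows, reusing a single seeded
        # generator so the result is deterministic for a given seed.
        rows, cols = x.shape
        out_fp4 = torch.empty(rows, cols // 2, dtype=torch.uint8, device=x.device)
        out_scales = torch.empty(rows, cols // 16, dtype=torch.float8_e4m3fn, device=x.device)

        gen = torch.Generator(device=x.device)
        gen.manual_seed(seed)

        num_slices = max(1, x.numel() / chunk_elems)
        slice_rows = max(1, round(rows / num_slices))

        for i in range(0, rows, slice_rows):
            fp4, scales = quantize_rows(x[i:i + slice_rows], gen)
            out_fp4[i:i + slice_rows].copy_(fp4)
            out_scales[i:i + slice_rows].copy_(scales)
        return out_fp4, out_scales

    fp4, scales = quantize_by_row_chunks(torch.randn(64, 128), chunk_elems=1024)
    print(fp4.shape, scales.shape)  # torch.Size([64, 64]) torch.Size([64, 8])

The per-block scale computation operates row by row, so splitting the tensor into row slices
does not change the scales. The shared generator, however, is now consumed slice by slice
instead of in one full-tensor draw, so the stochastic rounding decisions can differ from the
old single-shot path; that is presumably the "changes the results a bit" the commit message
refers to, while the bounded per-slice intermediates are likely where the speed and memory
win comes from.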