From b3c0e4de57bfd27e3dd94bd9723bb4c714668a09 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:33:54 -0800 Subject: [PATCH] Make loras work on nvfp4 models. (#11837) The initial applying is a bit slow but will probably be sped up in the future. --- comfy/float.py | 113 +++++++++++++++++++++++++++++++++++++++++++++ comfy/ops.py | 2 +- comfy/quant_ops.py | 37 ++++++++++++++- requirements.txt | 2 +- 4 files changed, 150 insertions(+), 4 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 521316fd2..e638b1ff7 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0): return output return value.to(dtype=dtype) + + +# TODO: improve this? +def stochastic_float_to_fp4_e2m1(x, generator): + sign = torch.signbit(x).to(torch.uint8) + x_abs = x.abs() + + exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3) + x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 + + x_abs = x.abs() + exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3) + + mantissa = torch.where( + exp > 0, + (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x_abs * 2.0) + ).round().to(torch.uint8) + + fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa + + fp4_flat = fp4.view(-1) + packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] + return packed.reshape(list(x.shape)[:-1] + [-1]) + + +def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + + def ceil_div(a, b): + return (a + b - 1) // b + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + if flatten: + return rearranged.flatten() + + return rearranged.reshape(padded_rows, padded_cols) + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + F4_E2M1_MAX = 6.0 + F8_E4M3_MAX = 448.0 + + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + orig_shape = x.shape + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + # Note: We update orig_shape because the output tensor logic below assumes x.shape matches + # what we want to produce. If we pad here, we want the padded output. + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + max_abs = torch.amax(torch.abs(x), dim=-1) + block_scale = max_abs / F4_E2M1_MAX + scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype) + scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype) + + # Handle zero blocks (from padding): avoid 0/0 NaN + zero_scale_mask = (total_scale == 0) + total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale) + + x = x / total_scale_safe.unsqueeze(-1) + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x) + + x = x.view(orig_shape) + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + + blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False) + return data_lp, blocked_scales diff --git a/comfy/ops.py b/comfy/ops.py index 9c0b54ff4..415c39e92 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: # dtype is now implicit in the layout class - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8324be42a..7a61203c3 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -7,7 +7,7 @@ try: QuantizedTensor, QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, - TensorCoreNVFP4Layout, # Direct import, no wrapper needed + TensorCoreNVFP4Layout as _CKNvfp4Layout, register_layout_op, register_layout_class, get_layout_class, @@ -34,7 +34,7 @@ except ImportError as e: class _CKFp8Layout: pass - class TensorCoreNVFP4Layout: + class _CKNvfp4Layout: pass def register_layout_class(name, cls): @@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreNVFP4Layout(_CKNvfp4Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + if scale is None or (isinstance(scale, str) and scale == "recalculate"): + scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX) + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) + + params = cls.Params( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + block_scale=block_scale, + ) + return qdata, params + + class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): FP8_DTYPE = torch.float8_e4m3fn diff --git a/requirements.txt b/requirements.txt index 077c8930a..43737056e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,7 +21,7 @@ psutil alembic SQLAlchemy av>=14.2.0 -comfy-kitchen>=0.2.5 +comfy-kitchen>=0.2.6 #non essential dependencies: kornia>=0.7.1