Make loras work on nvfp4 models. (#11837)

Applying a lora for the first time is a bit slow, but this will probably be sped up in the
future.
comfyanonymous 2026-01-12 19:33:54 -08:00 committed by GitHub
parent ecaeeb990d
commit b3c0e4de57
4 changed files with 150 additions and 4 deletions
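
For context, applying a lora to a layer whose weight is stored as NVFP4 means the patched weight has to be requantized, and the code below does that requantization with stochastic rounding. A minimal, hypothetical sketch of the flow, assuming an apply_lora_to_nvfp4 helper, a dequantize() method and a precomputed lora_delta (illustrative names only; of the calls below, only QuantizedTensor.from_float and its arguments actually appear in this diff):

    import torch

    def apply_lora_to_nvfp4(quantized_weight, lora_delta, layout_type, seed=0):
        # Materialize the NVFP4 weight in a full-precision dtype
        # (assumes the quantized tensor exposes some dequantize() helper).
        w = quantized_weight.dequantize().to(torch.float32)
        # Add the already-computed lora delta.
        w += lora_delta.to(w.dtype)
        # Requantize back to NVFP4; passing a seed enables the new stochastic
        # rounding path so repeated requantization stays unbiased on average.
        # QuantizedTensor comes from comfy_kitchen (see the import diff below).
        return QuantizedTensor.from_float(w, layout_type, scale="recalculate",
                                          stochastic_rounding=seed, inplace_ops=True)

The slow first application is presumably this full per-weight requantization running as plain PyTorch ops rather than a fused kernel.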

View File

@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0):
        return output
    return value.to(dtype=dtype)


# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
    # Stochastically round x to FP4 E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit)
    # and pack two 4-bit codes per uint8 along the last dimension.
    sign = torch.signbit(x).to(torch.uint8)
    x_abs = x.abs()
    exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3)
    # Add uniform noise on the order of the spacing between representable values at this
    # magnitude, then recompute the exponent so the plain round() below acts stochastically.
    x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
    x_abs = x.abs()
    exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3)
    mantissa = torch.where(
        exp > 0,
        (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0,
        (x_abs * 2.0)
    ).round().to(torch.uint8)

    fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa
    fp4_flat = fp4.view(-1)
    packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
    return packed.reshape(list(x.shape)[:-1] + [-1])
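
For reference (not part of the commit): the 4-bit E2M1 code assembled above is 1 sign bit, 2 exponent bits and 1 mantissa bit, giving magnitudes 0, 0.5, 1, 1.5, 2, 3, 4 and 6, and two codes are packed per byte with the even-indexed element in the high nibble. A minimal decoder for sanity-checking the encoding:

    def fp4_e2m1_decode(code: int) -> float:
        # `code` is a single 4-bit E2M1 value, already unpacked from its byte.
        sign = -1.0 if (code >> 3) & 1 else 1.0
        e = (code >> 1) & 0b11   # 2-bit exponent field
        m = code & 1             # 1-bit mantissa
        if e > 0:
            mag = (2.0 ** (e - 1)) * (1.0 + 0.5 * m)   # normal values: 1, 1.5, 2, 3, 4, 6
        else:
            mag = 0.5 * m                              # subnormal: 0 or 0.5
        return sign * mag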
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
    """
    def ceil_div(a, b):
        return (a + b - 1) // b

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros(
            (padded_rows, padded_cols),
            device=input_matrix.device,
            dtype=input_matrix.dtype,
        )
        padded[:rows, :cols] = input_matrix

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    if flatten:
        return rearranged.flatten()
    return rearranged.reshape(padded_rows, padded_cols)
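
A quick shape check (illustrative, not in the commit): the per-block FP8 scales for a 4096x4096 weight with 16-element blocks form a (4096, 256) matrix, which needs no padding here because 4096 is a multiple of 128 and 256 is a multiple of 4:

    scales = torch.randn(4096, 256)              # stand-in for the FP8 block scales
    swizzled = to_blocked(scales, flatten=True)
    # 32 row blocks of 128 rows times 64 column blocks of 4 columns -> element count unchanged
    assert swizzled.numel() == 4096 * 256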
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
    F4_E2M1_MAX = 6.0
    F8_E4M3_MAX = 448.0

    def roundup(x: int, multiple: int) -> int:
        """Round up x to the nearest multiple."""
        return ((x + multiple - 1) // multiple) * multiple

    orig_shape = x.shape

    # Handle padding
    if pad_16x:
        rows, cols = x.shape
        padded_rows = roundup(rows, 16)
        padded_cols = roundup(cols, 16)
        if padded_rows != rows or padded_cols != cols:
            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
            # Note: We update orig_shape because the output tensor logic below assumes x.shape matches
            # what we want to produce. If we pad here, we want the padded output.
            orig_shape = x.shape

    block_size = 16
    x = x.reshape(orig_shape[0], -1, block_size)
    max_abs = torch.amax(torch.abs(x), dim=-1)
    block_scale = max_abs / F4_E2M1_MAX
    scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
    scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
    total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)

    # Handle zero blocks (from padding): avoid 0/0 NaN
    zero_scale_mask = (total_scale == 0)
    total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
    x = x / total_scale_safe.unsqueeze(-1)

    generator = torch.Generator(device=x.device)
    generator.manual_seed(seed)

    x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
    x = x.view(orig_shape)
    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
    blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
    return data_lp, blocked_scales
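
A minimal usage sketch for the function above (random data, illustrative only); the per-tensor scale mirrors the amax / (448 * 6) formula used in TensorCoreNVFP4Layout.quantize further down:

    w = torch.randn(128, 256, dtype=torch.bfloat16)
    per_tensor_scale = torch.amax(w.abs()).float() / (448.0 * 6.0)   # F8_E4M3_MAX * F4_E2M1_MAX
    data_lp, blocked_scales = stochastic_round_quantize_nvfp4(w, per_tensor_scale, pad_16x=False, seed=1234)
    assert data_lp.shape == (128, 128)                   # two FP4 values packed per uint8
    assert blocked_scales.dtype == torch.float8_e4m3fn   # one FP8 scale per 16-element block

Dequantization conceptually reverses the two scales: each decoded FP4 value is multiplied by its block's FP8 scale and then by per_tensor_scale.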

View File

@ -699,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
            if getattr(self, 'layout_type', None) is not None:
                # dtype is now implicit in the layout class
                weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
                weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
            else:
                weight = weight.to(self.weight.dtype)

            if return_weight:

View File

@ -7,7 +7,7 @@ try:
        QuantizedTensor,
        QuantizedLayout,
        TensorCoreFP8Layout as _CKFp8Layout,
        TensorCoreNVFP4Layout,  # Direct import, no wrapper needed
        TensorCoreNVFP4Layout as _CKNvfp4Layout,
        register_layout_op,
        register_layout_class,
        get_layout_class,
@ -34,7 +34,7 @@ except ImportError as e:
    class _CKFp8Layout:
        pass

    class TensorCoreNVFP4Layout:
    class _CKNvfp4Layout:
        pass

    def register_layout_class(name, cls):
@ -84,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
        return qdata, params


class TensorCoreNVFP4Layout(_CKNvfp4Layout):
    @classmethod
    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
        if tensor.dim() != 2:
            raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")

        orig_dtype = tensor.dtype
        orig_shape = tuple(tensor.shape)

        if scale is None or (isinstance(scale, str) and scale == "recalculate"):
            scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)

        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

        padded_shape = cls.get_padded_shape(orig_shape)
        needs_padding = padded_shape != orig_shape

        if stochastic_rounding > 0:
            qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
        else:
            qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)

        params = cls.Params(
            scale=scale,
            orig_dtype=orig_dtype,
            orig_shape=orig_shape,
            block_scale=block_scale,
        )
        return qdata, params


class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
    FP8_DTYPE = torch.float8_e4m3fn
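
The stochastic_rounding=seed argument that set_weight passes to QuantizedTensor.from_float above presumably ends up in this quantize method for NVFP4 layers; with a seed it takes the new comfy.float path, otherwise it defers to comfy_kitchen's ck.quantize_nvfp4. A small illustrative call, assuming comfy_kitchen imported successfully so the real _CKNvfp4Layout base supplies Params and get_padded_shape:

    w = torch.randn(256, 256)
    # seed > 0: uses the stochastic path added in this commit
    qdata, params = TensorCoreNVFP4Layout.quantize(w, scale="recalculate", stochastic_rounding=1234)
    # stochastic_rounding=0 (the default) would fall through to ck.quantize_nvfp4 instead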

View File

@ -21,7 +21,7 @@ psutil
alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.5
comfy-kitchen>=0.2.6
#non essential dependencies:
kornia>=0.7.1