mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-26 03:23:34 +08:00
Add stochastic rounding for LoRAs
This commit is contained in:
parent
9eceec64d7
commit
7220240517
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
|||||||
output_block[i:i + slice_size].copy_(block)
|
output_block[i:i + slice_size].copy_(block)
|
||||||
|
|
||||||
return output_fp4, to_blocked(output_block, flatten=False)
|
return output_fp4, to_blocked(output_block, flatten=False)
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
|
||||||
|
def roundup(x_val, multiple):
|
||||||
|
return ((x_val + multiple - 1) // multiple) * multiple
|
||||||
|
|
||||||
|
if pad_32x:
|
||||||
|
rows, cols = x.shape
|
||||||
|
padded_rows = roundup(rows, 32)
|
||||||
|
padded_cols = roundup(cols, 32)
|
||||||
|
if padded_rows != rows or padded_cols != cols:
|
||||||
|
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||||
|
|
||||||
|
F8_E4M3_MAX = 448.0
|
||||||
|
E8M0_BIAS = 127
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
|
||||||
|
rows, cols = x.shape
|
||||||
|
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
|
||||||
|
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
|
||||||
|
|
||||||
|
# E8M0 block scales (power-of-2 exponents)
|
||||||
|
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
|
||||||
|
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
|
||||||
|
block_scales_e8m0 = exp_biased.to(torch.uint8)
|
||||||
|
|
||||||
|
zero_mask = (max_abs == 0)
|
||||||
|
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
|
||||||
|
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
|
||||||
|
|
||||||
|
# Scale per-block then stochastic round
|
||||||
|
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
|
||||||
|
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
|
||||||
|
|
||||||
|
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
|
||||||
|
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||||
|
|||||||
@ -8,7 +8,7 @@ try:
|
|||||||
QuantizedLayout,
|
QuantizedLayout,
|
||||||
TensorCoreFP8Layout as _CKFp8Layout,
|
TensorCoreFP8Layout as _CKFp8Layout,
|
||||||
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
TensorCoreNVFP4Layout as _CKNvfp4Layout,
|
||||||
TensorCoreMXFP8Layout,
|
TensorCoreMXFP8Layout as _CKMxfp8Layout,
|
||||||
register_layout_op,
|
register_layout_op,
|
||||||
register_layout_class,
|
register_layout_class,
|
||||||
get_layout_class,
|
get_layout_class,
|
||||||
@ -38,7 +38,7 @@ except ImportError as e:
|
|||||||
class _CKNvfp4Layout:
|
class _CKNvfp4Layout:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class TensorCoreMXFP8Layout:
|
class _CKMxfp8Layout:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def register_layout_class(name, cls):
|
def register_layout_class(name, cls):
|
||||||
@ -88,6 +88,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
|||||||
return qdata, params
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
|
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
|
if tensor.dim() != 2:
|
||||||
|
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||||
|
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
orig_shape = tuple(tensor.shape)
|
||||||
|
|
||||||
|
padded_shape = cls.get_padded_shape(orig_shape)
|
||||||
|
needs_padding = padded_shape != orig_shape
|
||||||
|
|
||||||
|
if stochastic_rounding > 0:
|
||||||
|
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||||
|
else:
|
||||||
|
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||||
|
|
||||||
|
params = cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=orig_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||||
@classmethod
|
@classmethod
|
||||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user