From 684296148ed30bc39e34c7c069905e166153c11e Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 29 May 2026 12:23:42 +1000 Subject: [PATCH] float: use CK stochastic rounding cuda kernel (#13971) --- comfy/float.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/comfy/float.py b/comfy/float.py index 184b3d6d0..3c82d6359 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,5 +1,20 @@ +import logging + import torch +_CK_STOCHASTIC_ROUNDING_AVAILABLE = False +try: + import comfy_kitchen as ck + _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8 + _CK_STOCHASTIC_ROUNDING_AVAILABLE = True +except (AttributeError, ImportError): + logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.") + +if not _CK_STOCHASTIC_ROUNDING_AVAILABLE: + def _ck_stochastic_rounding_fp8(value, rng, dtype): + raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding") + + def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): mantissa_scaled = torch.where( normal_mask, @@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) + if _CK_STOCHASTIC_ROUNDING_AVAILABLE: + rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator) + return _ck_stochastic_rounding_fp8(value, rng, dtype) + output = torch.empty_like(value, dtype=dtype) num_slices = max(1, (value.numel() / (4096 * 4096))) slice_size = max(1, round(value.shape[0] / num_slices))