From 0912fd5ec3c156d4de706f6181a3b352430542f1 Mon Sep 17 00:00:00 2001 From: Rattus Date: Tue, 19 May 2026 10:48:50 +1000 Subject: [PATCH] float: use CK stochastic rounding cuda kernel --- comfy/float.py | 19 +++++++++++++++++++ requirements.txt | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/comfy/float.py b/comfy/float.py index 184b3d6d0..3c82d6359 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,5 +1,20 @@ +import logging + import torch +_CK_STOCHASTIC_ROUNDING_AVAILABLE = False +try: + import comfy_kitchen as ck + _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8 + _CK_STOCHASTIC_ROUNDING_AVAILABLE = True +except (AttributeError, ImportError): + logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.") + +if not _CK_STOCHASTIC_ROUNDING_AVAILABLE: + def _ck_stochastic_rounding_fp8(value, rng, dtype): + raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding") + + def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): mantissa_scaled = torch.where( normal_mask, @@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) + if _CK_STOCHASTIC_ROUNDING_AVAILABLE: + rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator) + return _ck_stochastic_rounding_fp8(value, rng, dtype) + output = torch.empty_like(value, dtype=dtype) num_slices = max(1, (value.numel() / (4096 * 4096))) slice_size = max(1, round(value.shape[0] / num_slices)) diff --git a/requirements.txt b/requirements.txt index f499a10ae..c748ff0a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy>=2.0.0 filelock av>=14.2.0 -comfy-kitchen>=0.2.8 +comfy-kitchen>=0.2.9 comfy-aimdo==0.3.0 requests simpleeval>=1.0.0