diff --git a/comfy/float.py b/comfy/float.py index 184b3d6d0..3c82d6359 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,5 +1,20 @@ +import logging + import torch +_CK_STOCHASTIC_ROUNDING_AVAILABLE = False +try: + import comfy_kitchen as ck + _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8 + _CK_STOCHASTIC_ROUNDING_AVAILABLE = True +except (AttributeError, ImportError): + logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.") + +if not _CK_STOCHASTIC_ROUNDING_AVAILABLE: + def _ck_stochastic_rounding_fp8(value, rng, dtype): + raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding") + + def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): mantissa_scaled = torch.where( normal_mask, @@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) + if _CK_STOCHASTIC_ROUNDING_AVAILABLE: + rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator) + return _ck_stochastic_rounding_fp8(value, rng, dtype) + output = torch.empty_like(value, dtype=dtype) num_slices = max(1, (value.numel() / (4096 * 4096))) slice_size = max(1, round(value.shape[0] / num_slices)) diff --git a/requirements.txt b/requirements.txt index e76ed0034..0617667e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.85 +comfyui-workflow-templates==0.9.91 comfyui-embedded-docs==0.5.1 torch torchsde @@ -22,7 +22,7 @@ alembic SQLAlchemy>=2.0.0 filelock av>=16.0.0 -comfy-kitchen>=0.2.8 +comfy-kitchen==0.2.9 comfy-aimdo==0.4.5 requests simpleeval>=1.0.0