mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-01 20:07:37 +08:00
float: use CK stochastic rounding cuda kernel (#13971)
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
This commit is contained in:
parent
ade4dfd96a
commit
684296148e
@ -1,5 +1,20 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
|
||||||
|
try:
|
||||||
|
import comfy_kitchen as ck
|
||||||
|
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
|
||||||
|
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
|
||||||
|
except (AttributeError, ImportError):
|
||||||
|
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
|
||||||
|
|
||||||
|
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||||
|
def _ck_stochastic_rounding_fp8(value, rng, dtype):
|
||||||
|
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
|
||||||
|
|
||||||
|
|
||||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||||
mantissa_scaled = torch.where(
|
mantissa_scaled = torch.where(
|
||||||
normal_mask,
|
normal_mask,
|
||||||
@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
|
|||||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||||
generator = torch.Generator(device=value.device)
|
generator = torch.Generator(device=value.device)
|
||||||
generator.manual_seed(seed)
|
generator.manual_seed(seed)
|
||||||
|
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||||
|
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
|
||||||
|
return _ck_stochastic_rounding_fp8(value, rng, dtype)
|
||||||
|
|
||||||
output = torch.empty_like(value, dtype=dtype)
|
output = torch.empty_like(value, dtype=dtype)
|
||||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user