Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-08-26 22:56:02 +03:00 committed by GitHub
commit 4193c15afe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 24 deletions

View File

@ -1,7 +1,18 @@
import torch
import math
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
#Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype):
def manual_stochastic_round_to_float8(x, dtype, generator=None):
if dtype == torch.float8_e4m3fn:
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
elif dtype == torch.float8_e5m2:
@ -9,44 +20,34 @@ def manual_stochastic_round_to_float8(x, dtype):
else:
raise ValueError("Unsupported dtype")
x = x.half()
sign = torch.sign(x)
abs_x = x.abs()
sign = torch.where(abs_x == 0, 0, sign)
# Combine exponent calculation and clamping
exponent = torch.clamp(
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
0, 2**EXPONENT_BITS - 1
)
# Combine mantissa calculation and rounding
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
# zero_mask = (abs_x == 0)
# subnormal_mask = (exponent == 0) & (abs_x != 0)
normal_mask = ~(exponent == 0)
mantissa_scaled = torch.where(
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
sign *= torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_floor = mantissa_scaled.floor()
mantissa = torch.where(
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
(mantissa_floor + 1) / (2**MANTISSA_BITS),
mantissa_floor / (2**MANTISSA_BITS)
)
result = torch.where(
normal_mask,
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
)
del abs_x
result = torch.where(abs_x == 0, 0, result)
return result.to(dtype=dtype)
return sign.to(dtype=dtype)
def stochastic_rounding(value, dtype):
def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float32:
return value.to(dtype=torch.float32)
if dtype == torch.float16:
@ -54,6 +55,8 @@ def stochastic_rounding(value, dtype):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
return manual_stochastic_round_to_float8(value, dtype)
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
return value.to(dtype=dtype)

View File

@ -30,6 +30,18 @@ import comfy.model_management
import comfy.lora
from comfy.types import UnetWrapperFunction
def string_to_seed(data):
crc = 0xFFFFFFFF
for byte in data:
if isinstance(byte, str):
byte = ord(byte)
crc ^= byte
for _ in range(8):
if crc & 1:
crc = (crc >> 1) ^ 0xEDB88320
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -309,7 +321,7 @@ class ModelPatcher:
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
if inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else: