mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
8615c86722
@ -1,7 +1,18 @@
|
||||
import torch
|
||||
import math
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
|
||||
mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
|
||||
return mantissa_scaled.floor() / (2**MANTISSA_BITS)
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype):
|
||||
def manual_stochastic_round_to_float8(x, dtype, generator=None):
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||
elif dtype == torch.float8_e5m2:
|
||||
@ -9,44 +20,34 @@ def manual_stochastic_round_to_float8(x, dtype):
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
x = x.half()
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
sign = torch.where(abs_x == 0, 0, sign)
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
|
||||
torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
||||
# zero_mask = (abs_x == 0)
|
||||
# subnormal_mask = (exponent == 0) & (abs_x != 0)
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
mantissa_scaled = torch.where(
|
||||
abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
|
||||
|
||||
sign *= torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
mantissa_floor = mantissa_scaled.floor()
|
||||
mantissa = torch.where(
|
||||
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
||||
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
||||
mantissa_floor / (2**MANTISSA_BITS)
|
||||
)
|
||||
result = torch.where(
|
||||
normal_mask,
|
||||
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
|
||||
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
|
||||
(2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
|
||||
(2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
|
||||
)
|
||||
del abs_x
|
||||
|
||||
result = torch.where(abs_x == 0, 0, result)
|
||||
return result.to(dtype=dtype)
|
||||
return sign.to(dtype=dtype)
|
||||
|
||||
|
||||
|
||||
def stochastic_rounding(value, dtype):
|
||||
def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float32:
|
||||
return value.to(dtype=torch.float32)
|
||||
if dtype == torch.float16:
|
||||
@ -54,6 +55,8 @@ def stochastic_rounding(value, dtype):
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
return manual_stochastic_round_to_float8(value, dtype)
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
return manual_stochastic_round_to_float8(value, dtype, generator=generator)
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
|
||||
@ -32,6 +32,18 @@ from .model_base import BaseModel
|
||||
from .model_management_types import ModelManageable, MemoryMeasurements
|
||||
from .types import UnetWrapperFunction
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
for byte in data:
|
||||
if isinstance(byte, str):
|
||||
byte = ord(byte)
|
||||
crc ^= byte
|
||||
for _ in range(8):
|
||||
if crc & 1:
|
||||
crc = (crc >> 1) ^ 0xEDB88320
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
@ -313,7 +325,7 @@ class ModelPatcher(ModelManageable):
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
out_weight = stochastic_rounding(out_weight, weight.dtype)
|
||||
out_weight = stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
@ -323,12 +335,21 @@ class ModelPatcher(ModelManageable):
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
load_completely = []
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
|
||||
loading.append((model_management.module_size(m), n, m))
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
module_mem = x[0]
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = model_management.module_size(m)
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
@ -360,9 +381,8 @@ class ModelPatcher(ModelManageable):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
mem_used = model_management.module_size(m)
|
||||
mem_counter += mem_used
|
||||
load_completely.append((mem_used, n, m))
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m))
|
||||
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user