diff --git a/comfy/float.py b/comfy/float.py index 1dbdafd59..57fd07099 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,7 +1,18 @@ import torch +import math + +def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): + mantissa_scaled = torch.where( + normal_mask, + (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), + (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) + ) + + mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator) + return mantissa_scaled.floor() / (2**MANTISSA_BITS) #Not 100% sure about this -def manual_stochastic_round_to_float8(x, dtype): +def manual_stochastic_round_to_float8(x, dtype, generator=None): if dtype == torch.float8_e4m3fn: EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7 elif dtype == torch.float8_e5m2: @@ -9,44 +20,34 @@ def manual_stochastic_round_to_float8(x, dtype): else: raise ValueError("Unsupported dtype") + x = x.half() sign = torch.sign(x) abs_x = x.abs() + sign = torch.where(abs_x == 0, 0, sign) # Combine exponent calculation and clamping exponent = torch.clamp( - torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS, + torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, 0, 2**EXPONENT_BITS - 1 ) # Combine mantissa calculation and rounding - # min_normal = 2.0 ** (-EXPONENT_BIAS + 1) - # zero_mask = (abs_x == 0) - # subnormal_mask = (exponent == 0) & (abs_x != 0) normal_mask = ~(exponent == 0) - mantissa_scaled = torch.where( + abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator) + + sign *= torch.where( normal_mask, - (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), - (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) - ) - mantissa_floor = mantissa_scaled.floor() - mantissa = torch.where( - torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor), - (mantissa_floor + 1) / (2**MANTISSA_BITS), - mantissa_floor / (2**MANTISSA_BITS) - ) - result = torch.where( - normal_mask, - sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa), - sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa + (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), + (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x ) + del abs_x - result = torch.where(abs_x == 0, 0, result) - return result.to(dtype=dtype) + return sign.to(dtype=dtype) -def stochastic_rounding(value, dtype): +def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float32: return value.to(dtype=torch.float32) if dtype == torch.float16: @@ -54,6 +55,8 @@ def stochastic_rounding(value, dtype): if dtype == torch.bfloat16: return value.to(dtype=torch.bfloat16) if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: - return manual_stochastic_round_to_float8(value, dtype) + generator = torch.Generator(device=value.device) + generator.manual_seed(seed) + return manual_stochastic_round_to_float8(value, dtype, generator=generator) return value.to(dtype=dtype) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3cb013acf..20928e294 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -32,6 +32,18 @@ from .model_base import BaseModel from .model_management_types import ModelManageable, MemoryMeasurements from .types import UnetWrapperFunction +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -313,7 +325,7 @@ class ModelPatcher(ModelManageable): else: temp_weight = weight.to(torch.float32, copy=True) out_weight = lora.calculate_weight(self.patches[key], temp_weight, key) - out_weight = stochastic_rounding(out_weight, weight.dtype) + out_weight = stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) if inplace_update: utils.copy_to_param(self.model, key, out_weight) else: @@ -323,12 +335,21 @@ class ModelPatcher(ModelManageable): mem_counter = 0 patch_counter = 0 lowvram_counter = 0 - load_completely = [] + loading = [] for n, m in self.model.named_modules(): + if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"): + loading.append((model_management.module_size(m), n, m)) + + load_completely = [] + loading.sort(reverse=True) + for x in loading: + n = x[1] + m = x[2] + module_mem = x[0] + lowvram_weight = False if not full_load and hasattr(m, "comfy_cast_weights"): - module_mem = model_management.module_size(m) if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 @@ -360,9 +381,8 @@ class ModelPatcher(ModelManageable): wipe_lowvram_weight(m) if hasattr(m, "weight"): - mem_used = model_management.module_size(m) - mem_counter += mem_used - load_completely.append((mem_used, n, m)) + mem_counter += module_mem + load_completely.append((module_mem, n, m)) load_completely.sort(reverse=True) for x in load_completely: