Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-08-19 22:54:45 +03:00 committed by GitHub
commit 9baf36e97b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 38 deletions

View File

@ -1,5 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: ComfyUI Frontend Issues
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
- name: ComfyUI Matrix Space
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).

View File

@ -19,25 +19,29 @@ def manual_stochastic_round_to_float8(x, dtype):
)
# Combine mantissa calculation and rounding
mantissa = abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0
mantissa_scaled = mantissa * (2**MANTISSA_BITS)
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
# zero_mask = (abs_x == 0)
# subnormal_mask = (exponent == 0) & (abs_x != 0)
normal_mask = ~(exponent == 0)
mantissa_scaled = torch.where(
normal_mask,
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
)
mantissa_floor = mantissa_scaled.floor()
mantissa = torch.where(
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
(mantissa_floor + 1) / (2**MANTISSA_BITS),
mantissa_floor / (2**MANTISSA_BITS)
)
result = torch.where(
normal_mask,
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
)
# Combine final result calculation
result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa)
# Handle zero case
zero_mask = (abs_x == 0)
result = torch.where(zero_mask, torch.zeros_like(result), result)
# Handle subnormal numbers
min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
result = torch.where((abs_x < min_normal) & (~zero_mask), torch.round(x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) * (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)), result)
result = torch.where(abs_x == 0, 0, result)
return result.to(dtype=dtype)

View File

@ -644,40 +644,46 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
unload_list = []
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
for n, m in self.model.named_modules():
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
unload_list.append((module_mem, n, m))
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter