Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-02-14 11:00:12 +03:00 committed by GitHub
commit 4d66aa9709
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 4 deletions

View File

@ -218,6 +218,7 @@ class ModelPatcher:
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
@ -277,6 +278,8 @@ class ModelPatcher:
n.object_patches_backup = self.object_patches_backup
n.parent = self
n.force_cast_weights = self.force_cast_weights
# attachments
n.attachments = {}
for k in self.attachments:
@ -424,6 +427,12 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def set_model_compute_dtype(self, dtype):
self.add_object_patch("manual_cast_dtype", dtype)
if dtype is not None:
self.force_cast_weights = True
self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
def add_weight_wrapper(self, name, function):
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
self.patches_uuid = uuid.uuid4()
@ -602,6 +611,7 @@ class ModelPatcher:
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
cast_weight = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
@ -620,8 +630,7 @@ class ModelPatcher:
m.bias_function = [LowVramPatch(bias_key, self.patches)]
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
cast_weight = True
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
@ -630,6 +639,10 @@ class ModelPatcher:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
if weight_key in self.weight_wrapper_patches:
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
@ -766,6 +779,7 @@ class ModelPatcher:
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
@ -775,7 +789,9 @@ class ModelPatcher:
if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches))
patch_counter += 1
cast_weight = True
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False

View File

@ -114,7 +114,7 @@ def test_load_extra_model_paths_expands_userpath(
mock_yaml_safe_load.assert_called_once()
# Check if open was called with the correct file path
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r')
mock_file.assert_called_once_with(dummy_yaml_file_name, 'r', encoding='utf-8')
@patch('builtins.open', new_callable=mock_open)

View File

@ -4,7 +4,7 @@ import folder_paths
import logging
def load_extra_path_config(yaml_path):
with open(yaml_path, 'r') as stream:
with open(yaml_path, 'r', encoding='utf-8') as stream:
config = yaml.safe_load(stream)
yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
for c in config: