From 019c7029ea324517ab88d7e61e79b739bc8f4e91 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 13 Feb 2025 20:34:03 -0500 Subject: [PATCH 1/2] Add a way to set a different compute dtype for the model at runtime. Currently only works for diffusion models. --- comfy/model_patcher.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aee0164c5..4dbe1b7aa 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -218,6 +218,7 @@ class ModelPatcher: self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update + self.force_cast_weights = False self.patches_uuid = uuid.uuid4() self.parent = None @@ -277,6 +278,8 @@ class ModelPatcher: n.object_patches_backup = self.object_patches_backup n.parent = self + n.force_cast_weights = self.force_cast_weights + # attachments n.attachments = {} for k in self.attachments: @@ -424,6 +427,12 @@ class ModelPatcher: def add_object_patch(self, name, obj): self.object_patches[name] = obj + def set_model_compute_dtype(self, dtype): + self.add_object_patch("manual_cast_dtype", dtype) + if dtype is not None: + self.force_cast_weights = True + self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this + def add_weight_wrapper(self, name, function): self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function] self.patches_uuid = uuid.uuid4() @@ -602,6 +611,7 @@ class ModelPatcher: if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue + cast_weight = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] @@ -620,8 +630,7 @@ class ModelPatcher: m.bias_function = [LowVramPatch(bias_key, self.patches)] patch_counter += 1 - m.prev_comfy_cast_weights = m.comfy_cast_weights - m.comfy_cast_weights = True + cast_weight = True else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) @@ -630,6 +639,10 @@ class ModelPatcher: mem_counter += module_mem load_completely.append((module_mem, n, m, params)) + if cast_weight: + m.prev_comfy_cast_weights = m.comfy_cast_weights + m.comfy_cast_weights = True + if weight_key in self.weight_wrapper_patches: m.weight_function.extend(self.weight_wrapper_patches[weight_key]) @@ -766,6 +779,7 @@ class ModelPatcher: weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if move_weight: + cast_weight = self.force_cast_weights m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: @@ -775,7 +789,9 @@ class ModelPatcher: if bias_key in self.patches: m.bias_function.append(LowVramPatch(bias_key, self.patches)) patch_counter += 1 + cast_weight = True + if cast_weight: m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False From 042a905c3791e466327764b79748cf26738a4c26 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Thu, 13 Feb 2025 17:39:04 -0800 Subject: [PATCH 2/2] Open yaml files with utf-8 encoding for extra_model_paths.yaml (#6807) * Using utf-8 encoding for yaml files. * Fix test assertion. --- tests-unit/utils/extra_config_test.py | 2 +- utils/extra_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests-unit/utils/extra_config_test.py b/tests-unit/utils/extra_config_test.py index b23f5bd08..6d232079e 100644 --- a/tests-unit/utils/extra_config_test.py +++ b/tests-unit/utils/extra_config_test.py @@ -114,7 +114,7 @@ def test_load_extra_model_paths_expands_userpath( mock_yaml_safe_load.assert_called_once() # Check if open was called with the correct file path - mock_file.assert_called_once_with(dummy_yaml_file_name, 'r') + mock_file.assert_called_once_with(dummy_yaml_file_name, 'r', encoding='utf-8') @patch('builtins.open', new_callable=mock_open) diff --git a/utils/extra_config.py b/utils/extra_config.py index d7b592855..b7196e36f 100644 --- a/utils/extra_config.py +++ b/utils/extra_config.py @@ -4,7 +4,7 @@ import folder_paths import logging def load_extra_path_config(yaml_path): - with open(yaml_path, 'r') as stream: + with open(yaml_path, 'r', encoding='utf-8') as stream: config = yaml.safe_load(stream) yaml_dir = os.path.dirname(os.path.abspath(yaml_path)) for c in config: