From 45beebd33cd086f1b46e7e7054ba065d3a999cfe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Jun 2023 17:34:11 -0400 Subject: [PATCH 1/4] Add a type of model patch useful for model merging. --- comfy/sd.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e6cda5131..0ff918cb9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -347,15 +347,23 @@ class ModelPatcher: def model_dtype(self): return self.model.get_dtype() - def add_patches(self, patches, strength=1.0): + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): p = {} model_sd = self.model.state_dict() for k in patches: if k in model_sd: p[k] = patches[k] - self.patches += [(strength, p)] + self.patches += [(strength_patch, p, strength_model)] return p.keys() + def model_state_dict(self): + sd = self.model.state_dict() + keys = list(sd.keys()) + for k in keys: + if not k.startswith("diffusion_model."): + sd.pop(k) + return sd + def patch_model(self): model_sd = self.model.state_dict() for p in self.patches: @@ -371,8 +379,14 @@ class ModelPatcher: self.backup[key] = weight.clone() alpha = p[0] + strength_model = p[2] - if len(v) == 4: #lora/locon + if strength_model != 1.0: + weight *= strength_model + + if len(v) == 1: + weight += alpha * (v[0]).type(weight.dtype).to(weight.device) + elif len(v) == 4: #lora/locon mat1 = v[0] mat2 = v[1] if v[2] is not None: From 8125b51a628e4733099c4765ce5ad4478ab85518 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Jun 2023 19:08:48 -0400 Subject: [PATCH 2/4] Keep a set of model_keys for faster add_patches. --- comfy/sd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 0ff918cb9..097fbb200 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -302,12 +302,14 @@ class ModelPatcher: t = model_sd[k] size += t.nelement() * t.element_size() self.size = size + self.model_keys = set(model_sd.keys()) return size def clone(self): n = ModelPatcher(self.model, self.size) n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) + n.model_keys = self.model_keys return n def set_model_tomesd(self, ratio): @@ -349,9 +351,8 @@ class ModelPatcher: def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): p = {} - model_sd = self.model.state_dict() for k in patches: - if k in model_sd: + if k in self.model_keys: p[k] = patches[k] self.patches += [(strength_patch, p, strength_model)] return p.keys() @@ -365,7 +366,7 @@ class ModelPatcher: return sd def patch_model(self): - model_sd = self.model.state_dict() + model_sd = self.model_state_dict() for p in self.patches: for k in p[1]: v = p[1][k] From bf3f27177529cd7b82ba4776cf99ed90d090081e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Jun 2023 19:17:03 -0400 Subject: [PATCH 3/4] Add some nodes for basic model merging. --- comfy_extras/nodes_model_merging.py | 55 +++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 56 insertions(+) create mode 100644 comfy_extras/nodes_model_merging.py diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py new file mode 100644 index 000000000..daf4b09ba --- /dev/null +++ b/comfy_extras/nodes_model_merging.py @@ -0,0 +1,55 @@ + + +class ModelMergeSimple: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, ratio): + m = model1.clone() + sd = model2.model_state_dict() + for k in sd: + m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) + return (m, ) + +class ModelMergeBlocks: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, **kwargs): + m = model1.clone() + sd = model2.model_state_dict() + default_ratio = next(iter(kwargs.values())) + + for k in sd: + ratio = default_ratio + k_unet = k[len("diffusion_model."):] + + for arg in kwargs: + if k_unet.startswith(arg): + ratio = kwargs[arg] + + m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelMergeSimple": ModelMergeSimple, + "ModelMergeBlocks": ModelMergeBlocks +} diff --git a/nodes.py b/nodes.py index cbb7d69ea..396abe308 100644 --- a/nodes.py +++ b/nodes.py @@ -1459,4 +1459,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) load_custom_nodes() From 51581dbfa9ce19f537e4cd110509ac5ab91dd74c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Jun 2023 19:37:43 -0400 Subject: [PATCH 4/4] Fix last commits causing an issue with the text encoder lora. --- comfy/sd.py | 11 ++++++----- comfy_extras/nodes_model_merging.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 097fbb200..e016bea07 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -357,12 +357,13 @@ class ModelPatcher: self.patches += [(strength_patch, p, strength_model)] return p.keys() - def model_state_dict(self): + def model_state_dict(self, filter_prefix=None): sd = self.model.state_dict() keys = list(sd.keys()) - for k in keys: - if not k.startswith("diffusion_model."): - sd.pop(k) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) return sd def patch_model(self): @@ -443,7 +444,7 @@ class ModelPatcher: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): - model_sd = self.model.state_dict() + model_sd = self.model_state_dict() keys = list(self.backup.keys()) for k in keys: model_sd[k][:] = self.backup[k] diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index daf4b09ba..52b73f702 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -14,7 +14,7 @@ class ModelMergeSimple: def merge(self, model1, model2, ratio): m = model1.clone() - sd = model2.model_state_dict() + sd = model2.model_state_dict("diffusion_model.") for k in sd: m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) return (m, ) @@ -35,7 +35,7 @@ class ModelMergeBlocks: def merge(self, model1, model2, **kwargs): m = model1.clone() - sd = model2.model_state_dict() + sd = model2.model_state_dict("diffusion_model.") default_ratio = next(iter(kwargs.values())) for k in sd: